diff --git a/backend/groth16/bn254/solidity.go b/backend/groth16/bn254/solidity.go index fb3c2fe7b8..5f0b1a504e 100644 --- a/backend/groth16/bn254/solidity.go +++ b/backend/groth16/bn254/solidity.go @@ -541,17 +541,36 @@ contract Verifier { } {{- end}} (uint256 Px, uint256 Py) = decompress_g1(compressedCommitmentPok); + + uint256[] memory publicAndCommitmentCommitted; {{- range $i := intRange $numCommitments }} + {{- $pcIndex := index $PublicAndCommitmentCommitted $i }} + {{- if gt (len $pcIndex) 0 }} + publicAndCommitmentCommitted = new uint256[]({{(len $pcIndex)}}); + assembly ("memory-safe") { + let publicAndCommitmentCommittedOffset := add(publicAndCommitmentCommitted, 0x20) + {{- $segment_start := index $pcIndex 0 }} + {{- $segment_end := index $pcIndex 0 }} + {{- $l := 0 }} + {{- range $k := intRange (sub (len $pcIndex) 1) }} + {{- $next := index $pcIndex (sum $k 1) }} + {{- if ne $next (sum $segment_end 1) }} + calldatacopy(add(publicAndCommitmentCommittedOffset, {{mul $l 0x20}}), add(input, {{mul 0x20 (sub $segment_start 1)}}), {{mul 0x20 (sum 1 (sub $segment_end $segment_start))}}) + {{- $segment_start = $next }} + {{- $l = (sum $k 1) }} + {{- end }} + {{- $segment_end = $next }} + {{- end }} + calldatacopy(add(publicAndCommitmentCommittedOffset, {{mul $l 0x20}}), add(input, {{mul 0x20 (sub $segment_start 1)}}), {{mul 0x20 (sum 1 (sub $segment_end $segment_start))}}) + } + {{- end }} + publicCommitments[{{$i}}] = uint256( sha256( abi.encodePacked( commitments[{{mul $i 2}}], - commitments[{{sum (mul $i 2) 1}}] - {{- $pcIndex := index $PublicAndCommitmentCommitted $i }} - {{- range $j := intRange (len $pcIndex) }} - {{- $l := index $pcIndex $j }} - ,input[{{sub $l 1}}] - {{- end }} + commitments[{{sum (mul $i 2) 1}}], + publicAndCommitmentCommitted ) ) ) % R; @@ -670,17 +689,35 @@ contract Verifier { {{- else }} // HashToField uint256[{{$numCommitments}}] memory publicCommitments; + uint256[] memory publicAndCommitmentCommitted; {{- range $i := intRange $numCommitments }} + {{- $pcIndex := index $PublicAndCommitmentCommitted $i }} + {{- if gt (len $pcIndex) 0 }} + publicAndCommitmentCommitted = new uint256[]({{(len $pcIndex)}}); + assembly ("memory-safe") { + let publicAndCommitmentCommittedOffset := add(publicAndCommitmentCommitted, 0x20) + {{- $segment_start := index $pcIndex 0 }} + {{- $segment_end := index $pcIndex 0 }} + {{- $l := 0 }} + {{- range $k := intRange (sub (len $pcIndex) 1) }} + {{- $next := index $pcIndex (sum $k 1) }} + {{- if ne $next (sum $segment_end 1) }} + calldatacopy(add(publicAndCommitmentCommittedOffset, {{mul $l 0x20}}), add(input, {{mul 0x20 (sub $segment_start 1)}}), {{mul 0x20 (sum 1 (sub $segment_end $segment_start))}}) + {{- $segment_start = $next }} + {{- $l = (sum $k 1) }} + {{- end }} + {{- $segment_end = $next }} + {{- end }} + calldatacopy(add(publicAndCommitmentCommittedOffset, {{mul $l 0x20}}), add(input, {{mul 0x20 (sub $segment_start 1)}}), {{mul 0x20 (sum 1 (sub $segment_end $segment_start))}}) + } + {{- end }} + publicCommitments[{{$i}}] = uint256( sha256( abi.encodePacked( commitments[{{mul $i 2}}], - commitments[{{sum (mul $i 2) 1}}] - {{- $pcIndex := index $PublicAndCommitmentCommitted $i }} - {{- range $j := intRange (len $pcIndex) }} - {{- $l := index $pcIndex $j }} - ,input[{{sub $l 1}}] - {{- end }} + commitments[{{sum (mul $i 2) 1}}], + publicAndCommitmentCommitted ) ) ) % R; diff --git a/backend/plonk/plonk.go b/backend/plonk/plonk.go index f75158e9f8..1b615664ee 100644 --- a/backend/plonk/plonk.go +++ b/backend/plonk/plonk.go @@ -322,7 +322,7 @@ func NewVerifyingKey(curveID ecc.ID) VerifyingKey { // SRSSize returns the required size of the kzg SRS for a given constraint system // Note that the SRS size in Lagrange form is a power of 2, // and the SRS size in canonical form need few extra elements (3) to account for the blinding factors -func SRSSize(ccs constraint.System) (sizeCanonical, sizeLagrange int) { +func SRSSize(ccs constraint.ConstraintSystem) (sizeCanonical, sizeLagrange int) { nbConstraints := ccs.GetNbConstraints() sizeSystem := nbConstraints + ccs.GetNbPublicVariables() diff --git a/std/evmprecompiles/05-expmod.go b/std/evmprecompiles/05-expmod.go index 6b1eb16123..442ba88095 100644 --- a/std/evmprecompiles/05-expmod.go +++ b/std/evmprecompiles/05-expmod.go @@ -1 +1,31 @@ package evmprecompiles + +import ( + "fmt" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/std/math/emulated/emparams" +) + +// Expmod implements [MODEXP] precompile contract at address 0x05. +// +// Internally, uses 4k elements for representing the base, exponent and modulus, +// upper bounding the sizes of the inputs. The runtime is constant regardless of +// the actual length of the inputs. +// +// [MODEXP]: https://ethereum.github.io/execution-specs/autoapi/ethereum/paris/vm/precompiled_contracts/expmod/index.html +func Expmod(api frontend.API, base, exp, modulus *emulated.Element[emparams.Mod1e4096]) *emulated.Element[emparams.Mod1e4096] { + // x^0 = 1 + // x mod 0 = 0 + f, err := emulated.NewField[emparams.Mod1e4096](api) + if err != nil { + panic(fmt.Sprintf("new field: %v", err)) + } + // in case modulus is zero, then need to compute with dummy values and return zero as a result + isZeroMod := f.IsZero(modulus) + modulus = f.Select(isZeroMod, f.One(), modulus) + res := f.ModExp(base, exp, modulus) + res = f.Select(isZeroMod, f.Zero(), res) + return res +} diff --git a/std/evmprecompiles/05-expmod_test.go b/std/evmprecompiles/05-expmod_test.go new file mode 100644 index 0000000000..5de7d95bcb --- /dev/null +++ b/std/evmprecompiles/05-expmod_test.go @@ -0,0 +1,86 @@ +package evmprecompiles + +import ( + "crypto/rand" + "fmt" + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/std/math/emulated/emparams" + "github.com/consensys/gnark/test" +) + +type expmodCircuit struct { + Base emulated.Element[emparams.Mod1e4096] + Exp emulated.Element[emparams.Mod1e4096] + Mod emulated.Element[emparams.Mod1e4096] + Result emulated.Element[emparams.Mod1e4096] + edgeCases bool +} + +func (c *expmodCircuit) Define(api frontend.API) error { + res := Expmod(api, &c.Base, &c.Exp, &c.Mod) + f, err := emulated.NewField[emparams.Mod1e4096](api) + if err != nil { + return fmt.Errorf("new field: %w", err) + } + if c.edgeCases { + // cannot use ModAssertIsEqual for edge cases. But the output is either + // 0 or 1 so can use AssertIsEqual + f.AssertIsEqual(res, &c.Result) + } else { + // for random case need to use ModAssertIsEqual + f.ModAssertIsEqual(&c.Result, res, &c.Mod) + } + return nil +} + +func testInstance(edgeCases bool, base, exp, modulus, result *big.Int) error { + circuit := &expmodCircuit{edgeCases: edgeCases} + assignment := &expmodCircuit{ + Base: emulated.ValueOf[emparams.Mod1e4096](base), + Exp: emulated.ValueOf[emparams.Mod1e4096](exp), + Mod: emulated.ValueOf[emparams.Mod1e4096](modulus), + Result: emulated.ValueOf[emparams.Mod1e4096](result), + } + return test.IsSolved(circuit, assignment, ecc.BLS12_377.ScalarField()) +} + +func TestRandomInstance(t *testing.T) { + assert := test.NewAssert(t) + for _, bits := range []int{256, 512, 1024, 2048, 4096} { + assert.Run(func(assert *test.Assert) { + modulus := new(big.Int).Lsh(big.NewInt(1), uint(bits)) + base, _ := rand.Int(rand.Reader, modulus) + exp, _ := rand.Int(rand.Reader, modulus) + res := new(big.Int).Exp(base, exp, modulus) + err := testInstance(false, base, exp, modulus, res) + assert.NoError(err) + }, fmt.Sprintf("random-%d", bits)) + } +} + +func TestEdgeCases(t *testing.T) { + assert := test.NewAssert(t) + testCases := []struct { + base, exp, modulus, result *big.Int + }{ + {big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0)}, // 0^0 = 0 mod 0 + {big.NewInt(0), big.NewInt(0), big.NewInt(1), big.NewInt(1)}, // 0^0 = 1 mod 1 + {big.NewInt(0), big.NewInt(0), big.NewInt(123), big.NewInt(1)}, // 0^0 = 1 mod 123 + {big.NewInt(123), big.NewInt(123), big.NewInt(0), big.NewInt(0)}, // 123^123 = 0 mod 0 + {big.NewInt(123), big.NewInt(123), big.NewInt(0), big.NewInt(0)}, // 123^123 = 0 mod 1 + {big.NewInt(0), big.NewInt(123), big.NewInt(123), big.NewInt(0)}, // 0^123 = 0 mod 123 + {big.NewInt(123), big.NewInt(0), big.NewInt(123), big.NewInt(1)}, // 123^0 = 1 mod 123 + + } + for i, tc := range testCases { + assert.Run(func(assert *test.Assert) { + err := testInstance(true, tc.base, tc.exp, tc.modulus, tc.result) + assert.NoError(err) + }, fmt.Sprintf("edge-%d", i)) + } +} diff --git a/std/evmprecompiles/doc.go b/std/evmprecompiles/doc.go index 7c515eaa51..9b7dc431a8 100644 --- a/std/evmprecompiles/doc.go +++ b/std/evmprecompiles/doc.go @@ -7,7 +7,7 @@ // 2. SHA256 ❌ -- in progress // 3. RIPEMD160 ❌ -- postponed // 4. ID ❌ -- trivial to implement without function -// 5. EXPMOD ❌ -- in progress +// 5. EXPMOD ✅ -- function [Expmod] // 6. BN_ADD ✅ -- function [ECAdd] // 7. BN_MUL ✅ -- function [ECMul] // 8. SNARKV ✅ -- function [ECPair] diff --git a/std/math/emulated/composition.go b/std/math/emulated/composition.go index 79b0216a80..bce35e7442 100644 --- a/std/math/emulated/composition.go +++ b/std/math/emulated/composition.go @@ -66,10 +66,7 @@ func decompose(input *big.Int, nbBits uint, res []*big.Int) error { // // then no such underflow happens and s = a-b (mod p) as the padding is multiple // of p. -func subPadding[T FieldParams](overflow uint, nbLimbs uint) []*big.Int { - var fp T - p := fp.Modulus() - bitsPerLimbs := fp.BitsPerLimb() +func subPadding(modulus *big.Int, bitsPerLimbs uint, overflow uint, nbLimbs uint) []*big.Int { // first, we build a number nLimbs, such that nLimbs > b; // here b is defined by its bounds, that is b is an element with nbLimbs of (bitsPerLimbs+overflow) @@ -86,8 +83,8 @@ func subPadding[T FieldParams](overflow uint, nbLimbs uint) []*big.Int { panic(fmt.Sprintf("recompose: %v", err)) } // mod reduce n, and negate it - n.Mod(n, p) - n.Sub(p, n) + n.Mod(n, modulus) + n.Sub(modulus, n) // construct pad such that: // pad := n - neg(n mod p) == kp diff --git a/std/math/emulated/composition_test.go b/std/math/emulated/composition_test.go index d1c59289cc..25ef0430f9 100644 --- a/std/math/emulated/composition_test.go +++ b/std/math/emulated/composition_test.go @@ -54,7 +54,7 @@ func testSubPadding[T FieldParams](t *testing.T) { assert := test.NewAssert(t) for i := fp.NbLimbs(); i < 2*fp.NbLimbs(); i++ { assert.Run(func(assert *test.Assert) { - limbs := subPadding[T](0, i) + limbs := subPadding(fp.Modulus(), fp.BitsPerLimb(), 0, i) padValue := new(big.Int) if err := recompose(limbs, fp.BitsPerLimb(), padValue); err != nil { assert.FailNow("recompose", err) diff --git a/std/math/emulated/custommod.go b/std/math/emulated/custommod.go new file mode 100644 index 0000000000..2f5cbaca1b --- /dev/null +++ b/std/math/emulated/custommod.go @@ -0,0 +1,99 @@ +package emulated + +import ( + "errors" + + "github.com/consensys/gnark/frontend" +) + +// ModMul computes a*b mod modulus. Instead of taking modulus as a constant +// parametrized by T, it is passed as an argument. This allows to use a variable +// modulus in the circuit. Type parameter T should be sufficiently big to fit a, +// b and modulus. Recommended to use [emparams.Mod1e512] or +// [emparams.Mod1e4096]. +// +// NB! circuit complexity depends on T rather on the actual length of the modulus. +func (f *Field[T]) ModMul(a, b *Element[T], modulus *Element[T]) *Element[T] { + res := f.mulMod(a, b, 0, modulus) + return res +} + +// ModAdd computes a+b mod modulus. Instead of taking modulus as a constant +// parametrized by T, it is passed as an argument. This allows to use a variable +// modulus in the circuit. Type parameter T should be sufficiently big to fit a, +// b and modulus. Recommended to use [emparams.Mod1e512] or +// [emparams.Mod1e4096]. +// +// NB! circuit complexity depends on T rather on the actual length of the modulus. +func (f *Field[T]) ModAdd(a, b *Element[T], modulus *Element[T]) *Element[T] { + // inlined version of [Field.reduceAndOp] which uses variable-modulus reduction + var nextOverflow uint + var err error + var target overflowError + for nextOverflow, err = f.addPreCond(a, b); errors.As(err, &target); nextOverflow, err = f.addPreCond(a, b) { + if errors.As(err, &target) { + if !target.reduceRight { + a = f.mulMod(a, f.shortOne(), 0, modulus) + } else { + b = f.mulMod(b, f.shortOne(), 0, modulus) + } + } + } + res := f.add(a, b, nextOverflow) + return res +} + +func (f *Field[T]) modSub(a, b *Element[T], modulus *Element[T]) *Element[T] { + // like fixed modulus subtraction, but for sub padding need to use hint + // instead of assuming T as a constant. And when doing as a hint, then need + // to assert that the padding is a multiple of the modulus (done inside callSubPaddingHint) + nextOverflow := max(b.overflow+1, a.overflow) + 1 + nbLimbs := max(len(a.Limbs), len(b.Limbs)) + limbs := make([]frontend.Variable, nbLimbs) + padding := f.computeSubPaddingHint(b.overflow, uint(nbLimbs), modulus) + for i := range limbs { + limbs[i] = padding.Limbs[i] + if i < len(a.Limbs) { + limbs[i] = f.api.Add(limbs[i], a.Limbs[i]) + } + if i < len(b.Limbs) { + limbs[i] = f.api.Sub(limbs[i], b.Limbs[i]) + } + } + res := f.newInternalElement(limbs, nextOverflow) + return res +} + +// ModAssertIsEqual asserts equality of a and b mod modulus. Instead of taking +// modulus as a constant parametrized by T, it is passed as an argument. This +// allows to use a variable modulus in the circuit. Type parameter T should be +// sufficiently big to fit a, b and modulus. Recommended to use +// [emparams.Mod1e512] or [emparams.Mod1e4096]. +// +// NB! circuit complexity depends on T rather on the actual length of the modulus. +func (f *Field[T]) ModAssertIsEqual(a, b *Element[T], modulus *Element[T]) { + // like fixed modulus AssertIsEqual, but uses current Sub implementation for + // computing the diff + diff := f.modSub(b, a, modulus) + f.checkZero(diff, modulus) +} + +// ModExp computes base^exp mod modulus. Instead of taking modulus as a constant +// parametrized by T, it is passed as an argument. This allows to use a variable +// modulus in the circuit. Type parameter T should be sufficiently big to fit +// base, exp and modulus. Recommended to use [emparams.Mod1e512] or +// [emparams.Mod1e4096]. +// +// NB! circuit complexity depends on T rather on the actual length of the modulus. +func (f *Field[T]) ModExp(base, exp, modulus *Element[T]) *Element[T] { + expBts := f.ToBits(exp) + n := len(expBts) + res := f.Select(expBts[0], base, f.One()) + base = f.ModMul(base, base, modulus) + for i := 1; i < n-1; i++ { + res = f.Select(expBts[i], f.ModMul(base, res, modulus), res) + base = f.ModMul(base, base, modulus) + } + res = f.Select(expBts[n-1], f.ModMul(base, res, modulus), res) + return res +} diff --git a/std/math/emulated/custommod_test.go b/std/math/emulated/custommod_test.go new file mode 100644 index 0000000000..a399769322 --- /dev/null +++ b/std/math/emulated/custommod_test.go @@ -0,0 +1,171 @@ +package emulated + +import ( + "crypto/rand" + "fmt" + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/emulated/emparams" + "github.com/consensys/gnark/test" +) + +type variableEquality[T FieldParams] struct { + Modulus Element[T] + A, B Element[T] +} + +func (c *variableEquality[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return fmt.Errorf("new variable modulus: %w", err) + } + f.ModAssertIsEqual(&c.A, &c.B, &c.Modulus) + return nil +} + +func TestVariableEquality(t *testing.T) { + assert := test.NewAssert(t) + modulus, _ := new(big.Int).SetString("4294967311", 10) + a := big.NewInt(10) + b := new(big.Int).Add(a, modulus) + circuit := &variableEquality[emparams.Mod1e512]{} + assignment := &variableEquality[emparams.Mod1e512]{ + Modulus: ValueOf[emparams.Mod1e512](modulus), + A: ValueOf[emparams.Mod1e512](a), + B: ValueOf[emparams.Mod1e512](b), + } + err := test.IsSolved(circuit, assignment, ecc.BLS12_377.ScalarField()) + assert.NoError(err) +} + +type variableAddition[T FieldParams] struct { + Modulus Element[T] + A, B Element[T] + Expected Element[T] +} + +func (c *variableAddition[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return fmt.Errorf("new variable modulus: %w", err) + } + res := f.ModAdd(&c.A, &c.B, &c.Modulus) + f.ModAssertIsEqual(&c.Expected, res, &c.Modulus) + return nil +} + +func TestVariableAddition(t *testing.T) { + assert := test.NewAssert(t) + modulus, _ := new(big.Int).SetString("4294967311", 10) + circuit := &variableAddition[emparams.Mod1e512]{} + assignment := &variableAddition[emparams.Mod1e512]{ + Modulus: ValueOf[emparams.Mod1e512](modulus), + A: ValueOf[emparams.Mod1e512](10), + B: ValueOf[emparams.Mod1e512](20), + Expected: ValueOf[emparams.Mod1e512](30), + } + err := test.IsSolved(circuit, assignment, ecc.BLS12_377.ScalarField()) + assert.NoError(err) +} + +type variableSubtraction[T FieldParams] struct { + Modulus Element[T] + A, B Element[T] + Expected Element[T] +} + +func (c *variableSubtraction[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return fmt.Errorf("new variable modulus: %w", err) + } + res := f.modSub(&c.A, &c.B, &c.Modulus) + f.ModAssertIsEqual(&c.Expected, res, &c.Modulus) + return nil +} + +func TestVariableSubtraction(t *testing.T) { + assert := test.NewAssert(t) + modulus, _ := new(big.Int).SetString("4294967311", 10) + circuit := &variableSubtraction[emparams.Mod1e512]{} + res := new(big.Int).Sub(modulus, big.NewInt(10)) + assignment := &variableSubtraction[emparams.Mod1e512]{ + Modulus: ValueOf[emparams.Mod1e512](modulus), + A: ValueOf[emparams.Mod1e512](10), + B: ValueOf[emparams.Mod1e512](20), + Expected: ValueOf[emparams.Mod1e512](res), + } + err := test.IsSolved(circuit, assignment, ecc.BLS12_377.ScalarField()) + assert.NoError(err) +} + +type variableMultiplication[T FieldParams] struct { + Modulus Element[T] + A, B Element[T] + Expected Element[T] +} + +func (c *variableMultiplication[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return fmt.Errorf("new variable modulus: %w", err) + } + res := f.ModMul(&c.A, &c.B, &c.Modulus) + f.ModAssertIsEqual(&c.Expected, res, &c.Modulus) + return nil +} + +func TestVariableMultiplication(t *testing.T) { + assert := test.NewAssert(t) + modulus, _ := new(big.Int).SetString("4294967311", 10) + a, _ := rand.Int(rand.Reader, modulus) + b, _ := rand.Int(rand.Reader, modulus) + exp := new(big.Int).Mul(a, b) + exp.Mod(exp, modulus) + circuit := &variableMultiplication[emparams.Mod1e512]{} + assignment := &variableMultiplication[emparams.Mod1e512]{ + Modulus: ValueOf[emparams.Mod1e512](modulus), + A: ValueOf[emparams.Mod1e512](a), + B: ValueOf[emparams.Mod1e512](b), + Expected: ValueOf[emparams.Mod1e512](exp), + } + err := test.IsSolved(circuit, assignment, ecc.BLS12_377.ScalarField()) + assert.NoError(err) +} + +type variableExp[T FieldParams] struct { + Modulus Element[T] + Base Element[T] + Exp Element[T] + Expected Element[T] +} + +func (c *variableExp[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return fmt.Errorf("new variable modulus: %w", err) + } + res := f.ModExp(&c.Base, &c.Exp, &c.Modulus) + f.ModAssertIsEqual(&c.Expected, res, &c.Modulus) + return nil +} + +func TestVariableExp(t *testing.T) { + assert := test.NewAssert(t) + modulus, _ := new(big.Int).SetString("4294967311", 10) + base, _ := rand.Int(rand.Reader, modulus) + exp, _ := rand.Int(rand.Reader, modulus) + expected := new(big.Int).Exp(base, exp, modulus) + circuit := &variableExp[emparams.Mod1e512]{} + assignment := &variableExp[emparams.Mod1e512]{ + Modulus: ValueOf[emparams.Mod1e512](modulus), + Base: ValueOf[emparams.Mod1e512](base), + Exp: ValueOf[emparams.Mod1e512](exp), + Expected: ValueOf[emparams.Mod1e512](expected), + } + err := test.IsSolved(circuit, assignment, ecc.BLS12_377.ScalarField()) + assert.NoError(err) +} diff --git a/std/math/emulated/doc.go b/std/math/emulated/doc.go index 61a6e54288..6f4685f253 100644 --- a/std/math/emulated/doc.go +++ b/std/math/emulated/doc.go @@ -184,5 +184,17 @@ The package currently does not explicitly differentiate between constant and variable elements. The builder may track some elements as being constants. Some operations have a fast track path for cases when all inputs are constants. There is [Field.MulConst], which provides variable by constant multiplication. + +# Variable-modulus operations + +The package also exposes methods for performing operations with variable +modulus. The modulus is represented as an element and is not required to be +prime. The methods for variable-modulus operations are [Field.ModMul], +[Field.ModAdd], [Field.ModExp] and [Field.ModAssertIsEqual]. The modulus is +passed as an argument to the operation. + +The type parameter for the [Field] should be sufficiently big to allow to fit +the inputs and the modulus. Recommended to use predefined [emparams.Mod1e512] or +[emparams.Mod1e4096]. */ package emulated diff --git a/std/math/emulated/element_test.go b/std/math/emulated/element_test.go index 8954fc4d69..675f296596 100644 --- a/std/math/emulated/element_test.go +++ b/std/math/emulated/element_test.go @@ -13,6 +13,7 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/frontend/cs/scs" + "github.com/consensys/gnark/std/math/emulated/emparams" "github.com/consensys/gnark/test" ) @@ -1059,3 +1060,41 @@ func testSum[T FieldParams](t *testing.T) { assert.CheckCircuit(circuit, test.WithValidAssignment(witness)) }, testName[T]()) } + +type expCircuit[T FieldParams] struct { + Base Element[T] + Exp Element[T] + Expected Element[T] +} + +func (c *expCircuit[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return fmt.Errorf("new variable modulus: %w", err) + } + res := f.Exp(&c.Base, &c.Exp) + f.AssertIsEqual(&c.Expected, res) + return nil +} + +func testExp[T FieldParams](t *testing.T) { + var fp T + assert := test.NewAssert(t) + assert.Run(func(assert *test.Assert) { + var circuit expCircuit[T] + base, _ := rand.Int(rand.Reader, fp.Modulus()) + exp, _ := rand.Int(rand.Reader, fp.Modulus()) + expected := new(big.Int).Exp(base, exp, fp.Modulus()) + assignment := &expCircuit[T]{ + Base: ValueOf[T](base), + Exp: ValueOf[T](exp), + Expected: ValueOf[T](expected), + } + assert.CheckCircuit(&circuit, test.WithValidAssignment(assignment)) + }, testName[T]()) +} +func TestExp(t *testing.T) { + testExp[Goldilocks](t) + testExp[BN254Fr](t) + testExp[emparams.Mod1e512](t) +} diff --git a/std/math/emulated/emparams/emparams.go b/std/math/emulated/emparams/emparams.go index b07fb6e96b..bbee1afabf 100644 --- a/std/math/emulated/emparams/emparams.go +++ b/std/math/emulated/emparams/emparams.go @@ -281,3 +281,41 @@ func (fp BLS24315Fp) Modulus() *big.Int { return ecc.BLS24_315.BaseField() } type BLS24315Fr struct{ fourLimbPrimeField } func (fr BLS24315Fr) Modulus() *big.Int { return ecc.BLS24_315.ScalarField() } + +// Mod1e4096 provides type parametrization for emulated aritmetic: +// - limbs: 64 +// - limb width: 64 bits +// +// The modulus for type parametrisation is 2^4096-1. +// +// This is non-prime modulus. It is mainly targeted for using variable-modulus +// operations (ModAdd, ModMul, ModExp, ModAssertIsEqual) for variable modulus +// arithmetic. +type Mod1e4096 struct{} + +func (Mod1e4096) NbLimbs() uint { return 64 } +func (Mod1e4096) BitsPerLimb() uint { return 64 } +func (Mod1e4096) IsPrime() bool { return false } +func (Mod1e4096) Modulus() *big.Int { + val, _ := new(big.Int).SetString("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", 16) + return val +} + +// Mod1e512 provides type parametrization for emulated aritmetic: +// - limbs: 8 +// - limb width: 64 bits +// +// The modulus for type parametrisation is 2^512-1. +// +// This is non-prime modulus. It is mainly targeted for using variable-modulus +// operations (ModAdd, ModMul, ModExp, ModAssertIsEqual) for variable modulus +// arithmetic. +type Mod1e512 struct{} + +func (Mod1e512) NbLimbs() uint { return 8 } +func (Mod1e512) BitsPerLimb() uint { return 64 } +func (Mod1e512) IsPrime() bool { return false } +func (Mod1e512) Modulus() *big.Int { + val, _ := new(big.Int).SetString("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", 16) + return val +} diff --git a/std/math/emulated/field_assert.go b/std/math/emulated/field_assert.go index a2809e4eb9..5c2c700663 100644 --- a/std/math/emulated/field_assert.go +++ b/std/math/emulated/field_assert.go @@ -46,7 +46,7 @@ func (f *Field[T]) AssertIsEqual(a, b *Element[T]) { } diff := f.Sub(b, a) - f.checkZero(diff) + f.checkZero(diff, nil) } // AssertIsLessOrEqual ensures that e is less or equal than a. For proper diff --git a/std/math/emulated/field_mul.go b/std/math/emulated/field_mul.go index 9a2671d08a..278b9a5024 100644 --- a/std/math/emulated/field_mul.go +++ b/std/math/emulated/field_mul.go @@ -58,21 +58,25 @@ type mulCheck[T FieldParams] struct { r *Element[T] // reduced value k *Element[T] // coefficient c *Element[T] // carry + p *Element[T] // modulus if non-nil } // evalRound1 evaluates first c(X), r(X) and k(X) at a given random point at[0]. // In the first round we do not assume that any of them is already evaluated as // they come directly from hint. -func (mc *mulCheck[T]) evalRound1(api frontend.API, at []frontend.Variable) { +func (mc *mulCheck[T]) evalRound1(at []frontend.Variable) { mc.c = mc.f.evalWithChallenge(mc.c, at) mc.r = mc.f.evalWithChallenge(mc.r, at) mc.k = mc.f.evalWithChallenge(mc.k, at) + if mc.p != nil { + mc.p = mc.f.evalWithChallenge(mc.p, at) + } } // evalRound2 now evaluates a and b at a given random point at[0]. However, it // may happen that a or b is equal to r from a previous mulcheck. In that case // we can reuse the evaluation to save constraints. -func (mc *mulCheck[T]) evalRound2(api frontend.API, at []frontend.Variable) { +func (mc *mulCheck[T]) evalRound2(at []frontend.Variable) { mc.a = mc.f.evalWithChallenge(mc.a, at) mc.b = mc.f.evalWithChallenge(mc.b, at) } @@ -81,6 +85,9 @@ func (mc *mulCheck[T]) evalRound2(api frontend.API, at []frontend.Variable) { // computation of p(ch) and (2^t-ch) can be shared over all mulCheck instances, // then we get them already evaluated as peval and coef. func (mc *mulCheck[T]) check(api frontend.API, peval, coef frontend.Variable) { + if mc.p != nil { + peval = mc.p.evaluation + } ls := api.Mul(mc.a.evaluation, mc.b.evaluation) rs := api.Add(mc.r.evaluation, api.Mul(peval, mc.k.evaluation), api.Mul(mc.c.evaluation, coef)) api.AssertIsEqual(ls, rs) @@ -99,14 +106,19 @@ func (mc *mulCheck[T]) cleanEvaluations() { mc.k.isEvaluated = false mc.c.evaluation = 0 mc.c.isEvaluated = false + if mc.p != nil { + mc.p.evaluation = 0 + mc.p.isEvaluated = false + } } // mulMod returns a*b mod r. In practice it computes the result using a hint and // defers the actual multiplication check. -func (f *Field[T]) mulMod(a, b *Element[T], _ uint) *Element[T] { +func (f *Field[T]) mulMod(a, b *Element[T], _ uint, p *Element[T]) *Element[T] { f.enforceWidthConditional(a) f.enforceWidthConditional(b) - k, r, c, err := f.callMulHint(a, b, true) + f.enforceWidthConditional(p) + k, r, c, err := f.callMulHint(a, b, true, p) if err != nil { panic(err) } @@ -117,18 +129,20 @@ func (f *Field[T]) mulMod(a, b *Element[T], _ uint) *Element[T] { c: c, k: k, r: r, + p: p, } f.mulChecks = append(f.mulChecks, mc) return r } // checkZero creates multiplication check a * 1 = 0 + k*p. -func (f *Field[T]) checkZero(a *Element[T]) { +func (f *Field[T]) checkZero(a *Element[T], p *Element[T]) { // the method works similarly to mulMod, but we know that we are multiplying // by one and expected result should be zero. f.enforceWidthConditional(a) + f.enforceWidthConditional(p) b := f.shortOne() - k, r, c, err := f.callMulHint(a, b, false) + k, r, c, err := f.callMulHint(a, b, false, p) if err != nil { panic(err) } @@ -139,6 +153,7 @@ func (f *Field[T]) checkZero(a *Element[T]) { c: c, k: k, r: r, // expected to be zero on zero limbs. + p: p, } f.mulChecks = append(f.mulChecks, mc) } @@ -191,6 +206,9 @@ func (f *Field[T]) performMulChecks(api frontend.API) error { toCommit = append(toCommit, f.mulChecks[i].r.Limbs...) toCommit = append(toCommit, f.mulChecks[i].k.Limbs...) toCommit = append(toCommit, f.mulChecks[i].c.Limbs...) + if f.mulChecks[i].p != nil { + toCommit = append(toCommit, f.mulChecks[i].p.Limbs...) + } } // we give all the inputs as inputs to obtain random verifier challenge. multicommit.WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error { @@ -207,11 +225,11 @@ func (f *Field[T]) performMulChecks(api frontend.API) error { } // evaluate all r, k, c for i := range f.mulChecks { - f.mulChecks[i].evalRound1(api, at) + f.mulChecks[i].evalRound1(at) } // assuming r is input to some other multiplication, then is already evaluated for i := range f.mulChecks { - f.mulChecks[i].evalRound2(api, at) + f.mulChecks[i].evalRound2(at) } // evaluate p(X) at challenge pval := f.evalWithChallenge(f.Modulus(), at) @@ -234,7 +252,7 @@ func (f *Field[T]) performMulChecks(api frontend.API) error { } // callMulHint uses hint to compute r, k and c. -func (f *Field[T]) callMulHint(a, b *Element[T], isMulMod bool) (quo, rem, carries *Element[T], err error) { +func (f *Field[T]) callMulHint(a, b *Element[T], isMulMod bool, customMod *Element[T]) (quo, rem, carries *Element[T], err error) { // compute the expected overflow after the multiplication of a*b to be able // to estimate the number of bits required to represent the result. nextOverflow, _ := f.mulPreCond(a, b) @@ -249,8 +267,15 @@ func (f *Field[T]) callMulHint(a, b *Element[T], isMulMod bool) (quo, rem, carri // we compute the width of the product of a*b, then we divide it by the // width of the modulus. We add 1 to the result to ensure that we have // enough space for the quotient. + modbits := uint(f.fParams.Modulus().BitLen()) + if customMod != nil { + // when we're using custom modulus, then we do not really know its + // length ahead of time. We assume worst case scenario and assume that + // the quotient can be the total length of the multiplication result. + modbits = 0 + } nbQuoLimbs := (uint(nbMultiplicationResLimbs(len(a.Limbs), len(b.Limbs)))*nbBits + nextOverflow + 1 - // - uint(f.fParams.Modulus().BitLen()) + // + modbits + // nbBits - 1) / nbBits // the remainder is always less than modulus so can represent on the same @@ -267,7 +292,11 @@ func (f *Field[T]) callMulHint(a, b *Element[T], isMulMod bool) (quo, rem, carri len(a.Limbs), nbQuoLimbs, } - hintInputs = append(hintInputs, f.Modulus().Limbs...) + modulusLimbs := f.Modulus().Limbs + if customMod != nil { + modulusLimbs = customMod.Limbs + } + hintInputs = append(hintInputs, modulusLimbs...) hintInputs = append(hintInputs, a.Limbs...) hintInputs = append(hintInputs, b.Limbs...) ret, err := f.api.NewHint(mulHint, int(nbQuoLimbs)+int(nbRemLimbs)+int(nbCarryLimbs), hintInputs...) @@ -328,7 +357,9 @@ func mulHint(field *big.Int, inputs, outputs []*big.Int) error { quo := new(big.Int) rem := new(big.Int) ab := new(big.Int).Mul(a, b) - quo.QuoRem(ab, p, rem) + if p.Cmp(new(big.Int)) != 0 { + quo.QuoRem(ab, p, rem) + } if err := decompose(quo, uint(nbBits), quoLimbs); err != nil { return fmt.Errorf("decompose quo: %w", err) } @@ -380,7 +411,7 @@ func mulHint(field *big.Int, inputs, outputs []*big.Int) error { // For multiplying by a constant, use [Field[T].MulConst] method which is more // efficient. func (f *Field[T]) Mul(a, b *Element[T]) *Element[T] { - return f.reduceAndOp(f.mulMod, f.mulPreCond, a, b) + return f.reduceAndOp(func(a, b *Element[T], u uint) *Element[T] { return f.mulMod(a, b, u, nil) }, f.mulPreCond, a, b) } // MulMod computes a*b and reduces it modulo the field order. The returned Element @@ -388,7 +419,7 @@ func (f *Field[T]) Mul(a, b *Element[T]) *Element[T] { // // Equivalent to [Field[T].Mul], kept for backwards compatibility. func (f *Field[T]) MulMod(a, b *Element[T]) *Element[T] { - return f.reduceAndOp(f.mulMod, f.mulPreCond, a, b) + return f.reduceAndOp(func(a, b *Element[T], u uint) *Element[T] { return f.mulMod(a, b, u, nil) }, f.mulPreCond, a, b) } // MulConst multiplies a by a constant c and returns it. We assume that the @@ -463,3 +494,18 @@ func (f *Field[T]) mulNoReduce(a, b *Element[T], nextoverflow uint) *Element[T] } return f.newInternalElement(resLimbs, nextoverflow) } + +// Exp computes base^exp modulo the field order. The returned Element has default +// number of limbs and zero overflow. +func (f *Field[T]) Exp(base, exp *Element[T]) *Element[T] { + expBts := f.ToBits(exp) + n := len(expBts) + res := f.Select(expBts[0], base, f.One()) + base = f.Mul(base, base) + for i := 1; i < n-1; i++ { + res = f.Select(expBts[i], f.Mul(base, res), res) + base = f.Mul(base, base) + } + res = f.Select(expBts[n-1], f.Mul(base, res), res) + return res +} diff --git a/std/math/emulated/field_ops.go b/std/math/emulated/field_ops.go index aeaf2c3059..a9f0d9cda3 100644 --- a/std/math/emulated/field_ops.go +++ b/std/math/emulated/field_ops.go @@ -176,7 +176,7 @@ func (f *Field[T]) Reduce(a *Element[T]) *Element[T] { panic("trying to reduce a constant, which happen to have an overflow flag set") } // slow path - use hint to reduce value - return f.mulMod(a, f.One(), 0) + return f.mulMod(a, f.One(), 0, nil) } // Sub subtracts b from a and returns it. Reduces locally if wouldn't fit into @@ -204,9 +204,10 @@ func (f *Field[T]) sub(a, b *Element[T], nextOverflow uint) *Element[T] { // first we have to compute padding to ensure that the subtraction does not // underflow. + var fp T nbLimbs := max(len(a.Limbs), len(b.Limbs)) limbs := make([]frontend.Variable, nbLimbs) - padLimbs := subPadding[T](b.overflow, uint(nbLimbs)) + padLimbs := subPadding(fp.Modulus(), fp.BitsPerLimb(), b.overflow, uint(nbLimbs)) for i := range limbs { limbs[i] = padLimbs[i] if i < len(a.Limbs) { diff --git a/std/math/emulated/hints.go b/std/math/emulated/hints.go index 6c1644c407..eab14b47e9 100644 --- a/std/math/emulated/hints.go +++ b/std/math/emulated/hints.go @@ -22,6 +22,7 @@ func GetHints() []solver.Hint { InverseHint, SqrtHint, mulHint, + subPaddingHint, } } @@ -153,3 +154,51 @@ func SqrtHint(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { return nil }) } + +// subPaddingHint computes the padding for the subtraction of two numbers. It +// ensures that the padding is a multiple of the modulus. Can be used to avoid +// underflow. +// +// In case of fixed modulus use subPadding instead. +func subPaddingHint(mod *big.Int, inputs, outputs []*big.Int) error { + if len(inputs) < 4 { + return fmt.Errorf("input must be at least four elements") + } + nbLimbs := int(inputs[0].Int64()) + bitsPerLimbs := uint(inputs[1].Uint64()) + overflow := uint(inputs[2].Uint64()) + retLimbs := int(inputs[3].Int64()) + if len(inputs[4:]) != nbLimbs { + return fmt.Errorf("input length mismatch") + } + if len(outputs) != retLimbs { + return fmt.Errorf("result does not fit into output") + } + pLimbs := inputs[4 : 4+nbLimbs] + p := new(big.Int) + if err := recompose(pLimbs, bitsPerLimbs, p); err != nil { + return fmt.Errorf("recompose modulus: %w", err) + } + padLimbs := subPadding(p, bitsPerLimbs, overflow, uint(nbLimbs)) + for i := range padLimbs { + outputs[i].Set(padLimbs[i]) + } + + return nil +} + +func (f *Field[T]) computeSubPaddingHint(overflow uint, nbLimbs uint, modulus *Element[T]) *Element[T] { + var fp T + inputs := []frontend.Variable{fp.NbLimbs(), fp.BitsPerLimb(), overflow, nbLimbs} + inputs = append(inputs, modulus.Limbs...) + res, err := f.api.NewHint(subPaddingHint, int(nbLimbs), inputs...) + if err != nil { + panic(fmt.Sprintf("sub padding hint: %v", err)) + } + for i := range res { + f.checker.Check(res[i], int(fp.BitsPerLimb()+overflow+1)) + } + padding := f.newInternalElement(res, fp.BitsPerLimb()+overflow+1) + f.checkZero(padding, modulus) + return padding +} diff --git a/std/math/polynomial/polynomial_test.go b/std/math/polynomial/polynomial_test.go index ccd2c08b64..4fd2929533 100644 --- a/std/math/polynomial/polynomial_test.go +++ b/std/math/polynomial/polynomial_test.go @@ -42,7 +42,7 @@ func testEvalPoly[FR emulated.FieldParams](t *testing.T, p []int64, at int64, ev Evaluation: emulated.ValueOf[FR](evaluation), } - assert.CheckCircuit(&evalPolyCircuit[FR]{P: make([]emulated.Element[FR], len(p))}, test.WithValidAssignment(&witness), test.NoSolidityChecks()) + assert.CheckCircuit(&evalPolyCircuit[FR]{P: make([]emulated.Element[FR], len(p))}, test.WithValidAssignment(&witness)) } func TestEvalPoly(t *testing.T) { @@ -97,7 +97,7 @@ func testEvalMultiLin[FR emulated.FieldParams](t *testing.T) { Evaluation: emulated.ValueOf[FR](17), } - assert.CheckCircuit(&evalMultiLinCircuit[FR]{M: make([]emulated.Element[FR], 4), At: make([]emulated.Element[FR], 2)}, test.WithValidAssignment(&witness), test.NoSolidityChecks()) + assert.CheckCircuit(&evalMultiLinCircuit[FR]{M: make([]emulated.Element[FR], 4), At: make([]emulated.Element[FR], 2)}, test.WithValidAssignment(&witness)) } type evalEqCircuit[FR emulated.FieldParams] struct { @@ -143,7 +143,7 @@ func testEvalEq[FR emulated.FieldParams](t *testing.T) { Eq: emulated.ValueOf[FR](148665), } - assert.CheckCircuit(&evalEqCircuit[FR]{X: make([]emulated.Element[FR], 4), Y: make([]emulated.Element[FR], 4)}, test.WithValidAssignment(&witness), test.NoSolidityChecks()) + assert.CheckCircuit(&evalEqCircuit[FR]{X: make([]emulated.Element[FR], 4), Y: make([]emulated.Element[FR], 4)}, test.WithValidAssignment(&witness)) } type interpolateLDECircuit[FR emulated.FieldParams] struct { @@ -179,7 +179,7 @@ func testInterpolateLDE[FR emulated.FieldParams](t *testing.T, at int64, values Expected: emulated.ValueOf[FR](expected), } - assert.CheckCircuit(&interpolateLDECircuit[FR]{Values: make([]emulated.Element[FR], len(values))}, test.WithValidAssignment(assignment), test.NoSolidityChecks()) + assert.CheckCircuit(&interpolateLDECircuit[FR]{Values: make([]emulated.Element[FR], len(values))}, test.WithValidAssignment(assignment)) } func TestInterpolateLDEOnRange(t *testing.T) {