diff --git a/CHANGELOG.md b/CHANGELOG.md index 3dd5d76775..efd09efac4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,10 +19,15 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm ## [Unreleased](https://github.com/o1-labs/o1js/compare/26363465d...HEAD) +### Breaking changes + +- Change return signature of `ZkProgram.analyzeMethods()` to be a keyed object https://github.com/o1-labs/o1js/pull/1223 + ### Added - Provable non-native field arithmetic: - `Gadgets.ForeignField.{add, sub, sumchain}()` for addition and subtraction https://github.com/o1-labs/o1js/pull/1220 + - `Gadgets.ForeignField.{mul, inv, div}()` for multiplication and division https://github.com/o1-labs/o1js/pull/1223 - Comprehensive internal testing of constraint system layouts generated by new gadgets https://github.com/o1-labs/o1js/pull/1241 https://github.com/o1-labs/o1js/pull/1220 ### Changed diff --git a/src/bindings b/src/bindings index cea062267c..c8f8c631f2 160000 --- a/src/bindings +++ b/src/bindings @@ -1 +1 @@ -Subproject commit cea062267c2cf81edf50fee8ca9578824c056731 +Subproject commit c8f8c631f28b84c3d3859378a2fe857091207755 diff --git a/src/lib/gadgets/common.ts b/src/lib/gadgets/common.ts index acffca9349..196ca64e73 100644 --- a/src/lib/gadgets/common.ts +++ b/src/lib/gadgets/common.ts @@ -60,7 +60,7 @@ function toVars>( return Tuple.map(fields, toVar); } -function assert(stmt: boolean, message?: string) { +function assert(stmt: boolean, message?: string): asserts stmt { if (!stmt) { throw Error(message ?? 'Assertion failed'); } diff --git a/src/lib/gadgets/foreign-field.ts b/src/lib/gadgets/foreign-field.ts index 189ffcc3f8..75f41a7f82 100644 --- a/src/lib/gadgets/foreign-field.ts +++ b/src/lib/gadgets/foreign-field.ts @@ -1,10 +1,21 @@ -import { mod } from '../../bindings/crypto/finite_field.js'; +import { + inverse as modInverse, + mod, +} from '../../bindings/crypto/finite_field.js'; import { provableTuple } from '../../bindings/lib/provable-snarky.js'; import { Field } from '../field.js'; import { Gates, foreignFieldAdd } from '../gates.js'; import { Tuple } from '../util/types.js'; -import { assert, exists, toVars } from './common.js'; -import { L, lMask, multiRangeCheck, twoL, twoLMask } from './range-check.js'; +import { assert, bitSlice, exists, toVars } from './common.js'; +import { + l, + lMask, + multiRangeCheck, + l2, + l2Mask, + l3, + compactMultiRangeCheck, +} from './range-check.js'; export { ForeignField, Field3, Sign }; @@ -23,6 +34,10 @@ const ForeignField = { return sum([x, y], [-1n], f); }, sum, + + mul: multiply, + inv: inverse, + div: divide, }; /** @@ -70,7 +85,7 @@ function singleAdd(x: Field3, y: Field3, sign: Sign, f: bigint) { let y_ = toBigint3(y); // figure out if there's overflow - let r = collapse(x_) + sign * collapse(y_); + let r = combine(x_) + sign * combine(y_); let overflow = 0n; if (sign === 1n && r >= f) overflow = 1n; if (sign === -1n && r < 0n) overflow = -1n; @@ -78,9 +93,9 @@ function singleAdd(x: Field3, y: Field3, sign: Sign, f: bigint) { // do the add with carry // note: this "just works" with negative r01 - let r01 = collapse2(x_) + sign * collapse2(y_) - overflow * collapse2(f_); - let carry = r01 >> twoL; - r01 &= twoLMask; + let r01 = combine2(x_) + sign * combine2(y_) - overflow * combine2(f_); + let carry = r01 >> l2; + r01 &= l2Mask; let [r0, r1] = split2(r01); let r2 = x_[2] + sign * y_[2] - overflow * f_[2] + carry; @@ -92,19 +107,238 @@ function singleAdd(x: Field3, y: Field3, sign: Sign, f: bigint) { return { result: [r0, r1, r2] satisfies Field3, overflow }; } +function multiply(a: Field3, b: Field3, f: bigint): Field3 { + assert(f < 1n << 259n, 'Foreign modulus fits in 259 bits'); + + // constant case + if (a.every((x) => x.isConstant()) && b.every((x) => x.isConstant())) { + let ab = Field3.toBigint(a) * Field3.toBigint(b); + return Field3.from(mod(ab, f)); + } + + // provable case + let { r01, r2, q } = multiplyNoRangeCheck(a, b, f); + + // limb range checks on quotient and remainder + multiRangeCheck(q); + let r = compactMultiRangeCheck(r01, r2); + return r; +} + +function inverse(x: Field3, f: bigint): Field3 { + assert(f < 1n << 259n, 'Foreign modulus fits in 259 bits'); + + // constant case + if (x.every((x) => x.isConstant())) { + let xInv = modInverse(Field3.toBigint(x), f); + assert(xInv !== undefined, 'inverse exists'); + return Field3.from(xInv); + } + + // provable case + let xInv = exists(3, () => { + let xInv = modInverse(Field3.toBigint(x), f); + return xInv === undefined ? [0n, 0n, 0n] : split(xInv); + }); + multiRangeCheck(xInv); + // we need to bound xInv because it's a multiplication input + let xInv2Bound = weakBound(xInv[2], f); + + let one: Field2 = [Field.from(1n), Field.from(0n)]; + assertMul(x, xInv, one, f); + + // range check on result bound + // TODO: this uses two RCs too many.. need global RC stack + multiRangeCheck([xInv2Bound, Field.from(0n), Field.from(0n)]); + + return xInv; +} + +function divide( + x: Field3, + y: Field3, + f: bigint, + { allowZeroOverZero = false } = {} +) { + assert(f < 1n << 259n, 'Foreign modulus fits in 259 bits'); + + // constant case + if (x.every((x) => x.isConstant()) && y.every((x) => x.isConstant())) { + let yInv = modInverse(Field3.toBigint(y), f); + assert(yInv !== undefined, 'inverse exists'); + return Field3.from(mod(Field3.toBigint(x) * yInv, f)); + } + + // provable case + // to show that z = x/y, we prove that z*y = x and y != 0 (the latter avoids the unconstrained 0/0 case) + let z = exists(3, () => { + let yInv = modInverse(Field3.toBigint(y), f); + if (yInv === undefined) return [0n, 0n, 0n]; + return split(mod(Field3.toBigint(x) * yInv, f)); + }); + multiRangeCheck(z); + let z2Bound = weakBound(z[2], f); + assertMul(z, y, x, f); + + // range check on result bound + multiRangeCheck([z2Bound, Field.from(0n), Field.from(0n)]); + + if (!allowZeroOverZero) { + // assert that y != 0 mod f by checking that it doesn't equal 0 or f + // this works because we assume y[2] <= f2 + // TODO is this the most efficient way? + let y01 = y[0].add(y[1].mul(1n << l)); + y01.equals(0n).and(y[2].equals(0n)).assertFalse(); + let [f0, f1, f2] = split(f); + let f01 = combine2([f0, f1]); + y01.equals(f01).and(y[2].equals(f2)).assertFalse(); + } + + return z; +} + +/** + * Common logic for gadgets that expect a certain multiplication result a priori, instead of just using the remainder. + */ +function assertMul(x: Field3, y: Field3, xy: Field3 | Field2, f: bigint) { + let { r01, r2, q } = multiplyNoRangeCheck(x, y, f); + + // range check on quotient + multiRangeCheck(q); + + // bind remainder to input xy + if (xy.length === 2) { + let [xy01, xy2] = xy; + r01.assertEquals(xy01); + r2.assertEquals(xy2); + } else { + let xy01 = xy[0].add(xy[1].mul(1n << l)); + r01.assertEquals(xy01); + r2.assertEquals(xy[2]); + } +} + +/** + * Core building block for all gadgets using foreign field multiplication. + */ +function multiplyNoRangeCheck(a: Field3, b: Field3, f: bigint) { + // notation follows https://github.com/o1-labs/rfcs/blob/main/0006-ffmul-revised.md + let f_ = (1n << l3) - f; + let [f_0, f_1, f_2] = split(f_); + let f2 = f >> l2; + let f2Bound = (1n << l) - f2 - 1n; + + let witnesses = exists(21, () => { + // convert inputs to bigints + let [a0, a1, a2] = toBigint3(a); + let [b0, b1, b2] = toBigint3(b); + + // compute q and r such that a*b = q*f + r + let ab = combine([a0, a1, a2]) * combine([b0, b1, b2]); + let q = ab / f; + let r = ab - q * f; + + let [q0, q1, q2] = split(q); + let [r0, r1, r2] = split(r); + let r01 = combine2([r0, r1]); + + // compute product terms + let p0 = a0 * b0 + q0 * f_0; + let p1 = a0 * b1 + a1 * b0 + q0 * f_1 + q1 * f_0; + let p2 = a0 * b2 + a1 * b1 + a2 * b0 + q0 * f_2 + q1 * f_1 + q2 * f_0; + + let [p10, p110, p111] = split(p1); + let p11 = combine2([p110, p111]); + + // carry bottom limbs + let c0 = (p0 + (p10 << l) - r01) >> l2; + + // carry top limb + let c1 = (p2 - r2 + p11 + c0) >> l; + + // split high carry + let c1_00 = bitSlice(c1, 0, 12); + let c1_12 = bitSlice(c1, 12, 12); + let c1_24 = bitSlice(c1, 24, 12); + let c1_36 = bitSlice(c1, 36, 12); + let c1_48 = bitSlice(c1, 48, 12); + let c1_60 = bitSlice(c1, 60, 12); + let c1_72 = bitSlice(c1, 72, 12); + let c1_84 = bitSlice(c1, 84, 2); + let c1_86 = bitSlice(c1, 86, 2); + let c1_88 = bitSlice(c1, 88, 2); + let c1_90 = bitSlice(c1, 90, 1); + + // quotient high bound + let q2Bound = q2 + f2Bound; + + // prettier-ignore + return [ + r01, r2, + q0, q1, q2, + q2Bound, + p10, p110, p111, + c0, + c1_00, c1_12, c1_24, c1_36, c1_48, c1_60, c1_72, + c1_84, c1_86, c1_88, c1_90, + ]; + }); + + // prettier-ignore + let [ + r01, r2, + q0, q1, q2, + q2Bound, + p10, p110, p111, + c0, + c1_00, c1_12, c1_24, c1_36, c1_48, c1_60, c1_72, + c1_84, c1_86, c1_88, c1_90, + ] = witnesses; + + let q: Field3 = [q0, q1, q2]; + + // ffmul gate. this already adds the following zero row. + Gates.foreignFieldMul({ + left: a, + right: b, + remainder: [r01, r2], + quotient: q, + quotientHiBound: q2Bound, + product1: [p10, p110, p111], + carry0: c0, + carry1p: [c1_00, c1_12, c1_24, c1_36, c1_48, c1_60, c1_72], + carry1c: [c1_84, c1_86, c1_88, c1_90], + foreignFieldModulus2: f2, + negForeignFieldModulus: [f_0, f_1, f_2], + }); + + // multi-range check on internal values + multiRangeCheck([p10, p110, q2Bound]); + + // note: this function is supposed to be the most flexible interface to the ffmul gate. + // that's why we don't add range checks on q and r here, because there are valid use cases + // for not range-checking either of them -- for example, they could be wired to other + // variables that are already range-checked, or to constants / public inputs. + return { r01, r2, q }; +} + +function weakBound(x: Field, f: bigint) { + return x.add(lMask - (f >> l2)); +} + const Field3 = { /** * Turn a bigint into a 3-tuple of Fields */ from(x: bigint): Field3 { - return toField3(split(x)); + return Tuple.map(split(x), Field.from); }, /** * Turn a 3-tuple of Fields into a bigint */ toBigint(x: Field3): bigint { - return collapse(toBigint3(x)); + return combine(toBigint3(x)); }, /** @@ -116,23 +350,27 @@ const Field3 = { provable: provableTuple([Field, Field, Field]), }; -function toField3(x: bigint3): Field3 { - return Tuple.map(x, (x) => new Field(x)); -} +type Field2 = [Field, Field]; +const Field2 = { + toBigint(x: Field2): bigint { + return combine2(Tuple.map(x, (x) => x.toBigInt())); + }, +}; + function toBigint3(x: Field3): bigint3 { return Tuple.map(x, (x) => x.toBigInt()); } -function collapse([x0, x1, x2]: bigint3) { - return x0 + (x1 << L) + (x2 << twoL); +function combine([x0, x1, x2]: bigint3) { + return x0 + (x1 << l) + (x2 << l2); } function split(x: bigint): bigint3 { - return [x & lMask, (x >> L) & lMask, (x >> twoL) & lMask]; + return [x & lMask, (x >> l) & lMask, (x >> l2) & lMask]; } -function collapse2([x0, x1]: bigint3 | [bigint, bigint]) { - return x0 + (x1 << L); +function combine2([x0, x1]: bigint3 | [bigint, bigint]) { + return x0 + (x1 << l); } function split2(x: bigint): [bigint, bigint] { - return [x & lMask, (x >> L) & lMask]; + return [x & lMask, (x >> l) & lMask]; } diff --git a/src/lib/gadgets/foreign-field.unit-test.ts b/src/lib/gadgets/foreign-field.unit-test.ts index 0ff87927fc..4cb0d2d975 100644 --- a/src/lib/gadgets/foreign-field.unit-test.ts +++ b/src/lib/gadgets/foreign-field.unit-test.ts @@ -27,14 +27,32 @@ import { GateType } from '../../snarky.js'; const { ForeignField, Field3 } = Gadgets; function foreignField(F: FiniteField): ProvableSpec { - let rng = Random.otherField(F); return { - rng, + rng: Random.otherField(F), there: Field3.from, back: Field3.toBigint, provable: Field3.provable, }; } + +// for testing with inputs > f +function unreducedForeignField( + maxBits: number, + F: FiniteField +): ProvableSpec { + return { + rng: Random.bignat(1n << BigInt(maxBits)), + there: Field3.from, + back: Field3.toBigint, + provable: Field3.provable, + assertEqual(x, y, message) { + // need weak equality here because, while ffadd works on bigints larger than the modulus, + // it can't fully reduce them + assert(F.equal(x, y), message); + }, + }; +} + let sign = fromRandom(Random.oneOf(1n as const, -1n as const)); let fields = [ @@ -56,11 +74,57 @@ for (let F of fields) { eq2(F.add, (x, y) => ForeignField.add(x, y, F.modulus), 'add'); eq2(F.sub, (x, y) => ForeignField.sub(x, y, F.modulus), 'sub'); + eq2(F.mul, (x, y) => ForeignField.mul(x, y, F.modulus), 'mul'); + equivalentProvable({ from: [f], to: f })( + (x) => F.inverse(x) ?? throwError('no inverse'), + (x) => ForeignField.inv(x, F.modulus), + 'inv' + ); + eq2( + (x, y) => F.div(x, y) ?? throwError('no inverse'), + (x, y) => ForeignField.div(x, y, F.modulus), + 'div' + ); + + // tests with inputs that aren't reduced mod f + let big264 = unreducedForeignField(264, F); // this is the max size supported by our range checks / ffadd + let big258 = unreducedForeignField(258, F); // rough max size supported by ffmul + + equivalentProvable({ from: [big264, big264], to: big264 })( + F.add, + (x, y) => ForeignField.add(x, y, F.modulus), + 'add unreduced' + ); + // subtraction doesn't work with unreduced y because the range check on the result prevents x-y < -f + equivalentProvable({ from: [big264, f], to: big264 })( + F.sub, + (x, y) => ForeignField.sub(x, y, F.modulus), + 'sub unreduced' + ); + equivalentProvable({ from: [big258, big258], to: f })( + F.mul, + (x, y) => ForeignField.mul(x, y, F.modulus), + 'mul unreduced' + ); + equivalentProvable({ from: [big258], to: f })( + (x) => F.inverse(x) ?? throwError('no inverse'), + (x) => ForeignField.inv(x, F.modulus), + 'inv unreduced' + ); + // the div() gadget doesn't work with unreduced x because the backwards check (x/y)*y === x fails + // and it's not valid with unreduced y because we only assert y != 0, y != f but it can be 2f, 3f, etc. + // the combination of inv() and mul() is more flexible (but much more expensive, ~40 vs ~30 constraints) + equivalentProvable({ from: [big258, big258], to: f })( + (x, y) => F.div(x, y) ?? throwError('no inverse'), + (x, y) => ForeignField.mul(x, ForeignField.inv(y, F.modulus), F.modulus), + 'div unreduced' + ); // sumchain of 5 equivalentProvable({ from: [array(f, 5), array(sign, 4)], to: f })( (xs, signs) => sum(xs, signs, F), - (xs, signs) => ForeignField.sum(xs, signs, F.modulus) + (xs, signs) => ForeignField.sum(xs, signs, F.modulus), + 'sumchain 5' ); // sumchain up to 100 @@ -77,7 +141,7 @@ for (let F of fields) { let signs = ts.map((t) => t.sign); return ForeignField.sum(xs, signs, F.modulus); }, - 'sumchain' + 'sumchain long' ); } @@ -98,6 +162,24 @@ let ffProgram = ZkProgram({ return ForeignField.sum(xs, signs, F.modulus); }, }, + mul: { + privateInputs: [Field3.provable, Field3.provable], + method(x, y) { + return ForeignField.mul(x, y, F.modulus); + }, + }, + inv: { + privateInputs: [Field3.provable], + method(x) { + return ForeignField.inv(x, F.modulus); + }, + }, + div: { + privateInputs: [Field3.provable, Field3.provable], + method(x, y) { + return ForeignField.div(x, y, F.modulus); + }, + }, }, }); @@ -117,11 +199,29 @@ constraintSystem.fromZkProgram( ) ); +let mulChain: GateType[] = ['ForeignFieldMul', 'Zero']; +let mulLayout = ifNotAllConstant( + and( + contains([mulChain, mrc, mrc, mrc]), + withoutGenerics(equals([...mulChain, ...repeat(3, mrc)])) + ) +); +let invLayout = ifNotAllConstant( + and( + contains([mrc, mulChain, mrc, mrc, mrc]), + withoutGenerics(equals([...mrc, ...mulChain, ...repeat(3, mrc)])) + ) +); + +constraintSystem.fromZkProgram(ffProgram, 'mul', mulLayout); +constraintSystem.fromZkProgram(ffProgram, 'inv', invLayout); +constraintSystem.fromZkProgram(ffProgram, 'div', invLayout); + // tests with proving await ffProgram.compile(); -await equivalentAsync({ from: [array(f, chainLength)], to: f }, { runs: 5 })( +await equivalentAsync({ from: [array(f, chainLength)], to: f }, { runs: 3 })( (xs) => sum(xs, signs, F), async (xs) => { let proof = await ffProgram.sumchain(xs); @@ -131,6 +231,26 @@ await equivalentAsync({ from: [array(f, chainLength)], to: f }, { runs: 5 })( 'prove chain' ); +await equivalentAsync({ from: [f, f], to: f }, { runs: 3 })( + F.mul, + async (x, y) => { + let proof = await ffProgram.mul(x, y); + assert(await ffProgram.verify(proof), 'verifies'); + return proof.publicOutput; + }, + 'prove mul' +); + +await equivalentAsync({ from: [f, f], to: f }, { runs: 3 })( + (x, y) => F.div(x, y) ?? throwError('no inverse'), + async (x, y) => { + let proof = await ffProgram.div(x, y); + assert(await ffProgram.verify(proof), 'verifies'); + return proof.publicOutput; + }, + 'prove div' +); + // helper function sum(xs: bigint[], signs: (1n | -1n)[], F: FiniteField) { @@ -140,3 +260,7 @@ function sum(xs: bigint[], signs: (1n | -1n)[], F: FiniteField) { } return sum; } + +function throwError(message: string): T { + throw Error(message); +} diff --git a/src/lib/gadgets/gadgets.ts b/src/lib/gadgets/gadgets.ts index c7e1208a7f..b67ef58fd9 100644 --- a/src/lib/gadgets/gadgets.ts +++ b/src/lib/gadgets/gadgets.ts @@ -320,7 +320,7 @@ const Gadgets = { * A _foreign field_ is a finite field different from the native field of the proof system. * * The `ForeignField` namespace exposes operations like modular addition and multiplication, - * which work for any finite field of size less than 2^256. + * which work for any finite field of size less than 2^259. * * Foreign field elements are represented as 3 limbs of native field elements. * Each limb holds 88 bits of the total, in little-endian order. @@ -371,10 +371,7 @@ const Gadgets = { * * See {@link ForeignField.add} for assumptions and usage examples. * - * @param x left summand - * @param y right summand - * @param f modulus - * @returns x - y mod f + * @throws fails if `x - y < -f`, where the result cannot be brought back to a positive number by adding `f` once. */ sub(x: Field3, y: Field3, f: bigint) { return ForeignField.sub(x, y, f); @@ -412,6 +409,71 @@ const Gadgets = { sum(xs: Field3[], signs: (1n | -1n)[], f: bigint) { return ForeignField.sum(xs, signs, f); }, + + /** + * Foreign field multiplication: `x * y mod f` + * + * The modulus `f` does not need to be prime, but has to be smaller than 2^259. + * + * **Assumptions**: In addition to the assumption that input limbs are in the range [0, 2^88), as in all foreign field gadgets, + * this assumes an additional bound on the inputs: `x * y < 2^264 * p`, where p is the native modulus. + * We usually assert this bound by proving that `x[2] < f[2] + 1`, where `x[2]` is the most significant limb of x. + * To do this, use an 88-bit range check on `2^88 - x[2] - (f[2] + 1)`, and same for y. + * The implication is that x and y are _almost_ reduced modulo f. + * + * **Warning**: This gadget does not add the extra bound check on the result. + * So, to use the result in another foreign field multiplication, you have to add the bound check on it yourself, again. + * + * @example + * ```ts + * // example modulus: secp256k1 prime + * let f = (1n << 256n) - (1n << 32n) - 0b1111010001n; + * + * let x = Provable.witness(Field3.provable, () => Field3.from(f - 1n)); + * let y = Provable.witness(Field3.provable, () => Field3.from(f - 2n)); + * + * // range check x, y + * Gadgets.multiRangeCheck(x); + * Gadgets.multiRangeCheck(y); + * + * // prove additional bounds + * let x2Bound = x[2].add((1n << 88n) - 1n - (f >> 176n)); + * let y2Bound = y[2].add((1n << 88n) - 1n - (f >> 176n)); + * Gadgets.multiRangeCheck([x2Bound, y2Bound, Field(0n)]); + * + * // compute x * y mod f + * let z = ForeignField.mul(x, y, f); + * + * Provable.log(z); // ['2', '0', '0'] = limb representation of 2 = (-1)*(-2) mod f + * ``` + */ + mul(x: Field3, y: Field3, f: bigint) { + return ForeignField.mul(x, y, f); + }, + + /** + * Foreign field inverse: `x^(-1) mod f` + * + * See {@link ForeignField.mul} for assumptions on inputs and usage examples. + * + * This gadget adds an extra bound check on the result, so it can be used directly in another foreign field multiplication. + */ + inv(x: Field3, f: bigint) { + return ForeignField.inv(x, f); + }, + + /** + * Foreign field division: `x * y^(-1) mod f` + * + * See {@link ForeignField.mul} for assumptions on inputs and usage examples. + * + * This gadget adds an extra bound check on the result, so it can be used directly in another foreign field multiplication. + * + * @throws Different than {@link ForeignField.mul}, this fails on unreduced input `x`, because it checks that `x === (x/y)*y` and the right side will be reduced. + */ + div(x: Field3, y: Field3, f: bigint) { + return ForeignField.div(x, y, f); + }, }, /** diff --git a/src/lib/gadgets/range-check.ts b/src/lib/gadgets/range-check.ts index 5b2be2c8ee..1e44afdaf9 100644 --- a/src/lib/gadgets/range-check.ts +++ b/src/lib/gadgets/range-check.ts @@ -2,15 +2,8 @@ import { Field } from '../field.js'; import { Gates } from '../gates.js'; import { bitSlice, exists, toVar, toVars } from './common.js'; -export { - rangeCheck64, - multiRangeCheck, - compactMultiRangeCheck, - L, - twoL, - lMask, - twoLMask, -}; +export { rangeCheck64, multiRangeCheck, compactMultiRangeCheck }; +export { l, l2, l3, lMask, l2Mask }; /** * Asserts that x is in the range [0, 2^64) @@ -58,18 +51,19 @@ function rangeCheck64(x: Field) { } // default bigint limb size -const L = 88n; -const twoL = 2n * L; -const lMask = (1n << L) - 1n; -const twoLMask = (1n << twoL) - 1n; +const l = 88n; +const l2 = 2n * l; +const l3 = 3n * l; +const lMask = (1n << l) - 1n; +const l2Mask = (1n << l2) - 1n; /** * Asserts that x, y, z \in [0, 2^88) */ function multiRangeCheck([x, y, z]: [Field, Field, Field]) { if (x.isConstant() && y.isConstant() && z.isConstant()) { - if (x.toBigInt() >> L || y.toBigInt() >> L || z.toBigInt() >> L) { - throw Error(`Expected fields to fit in ${L} bits, got ${x}, ${y}, ${z}`); + if (x.toBigInt() >> l || y.toBigInt() >> l || z.toBigInt() >> l) { + throw Error(`Expected fields to fit in ${l} bits, got ${x}, ${y}, ${z}`); } return; } @@ -92,9 +86,9 @@ function multiRangeCheck([x, y, z]: [Field, Field, Field]) { function compactMultiRangeCheck(xy: Field, z: Field): [Field, Field, Field] { // constant case if (xy.isConstant() && z.isConstant()) { - if (xy.toBigInt() >> twoL || z.toBigInt() >> L) { + if (xy.toBigInt() >> l2 || z.toBigInt() >> l) { throw Error( - `Expected fields to fit in ${twoL} and ${L} bits respectively, got ${xy}, ${z}` + `Expected fields to fit in ${l2} and ${l} bits respectively, got ${xy}, ${z}` ); } let [x, y] = splitCompactLimb(xy.toBigInt()); @@ -113,7 +107,7 @@ function compactMultiRangeCheck(xy: Field, z: Field): [Field, Field, Field] { } function splitCompactLimb(x01: bigint): [bigint, bigint] { - return [x01 & lMask, x01 >> L]; + return [x01 & lMask, x01 >> l]; } function rangeCheck0Helper(x: Field, isCompact = false): [Field, Field] { diff --git a/src/lib/gadgets/range-check.unit-test.ts b/src/lib/gadgets/range-check.unit-test.ts index b5d5110807..47aafbf592 100644 --- a/src/lib/gadgets/range-check.unit-test.ts +++ b/src/lib/gadgets/range-check.unit-test.ts @@ -10,7 +10,7 @@ import { import { Random } from '../testing/property.js'; import { assert } from './common.js'; import { Gadgets } from './gadgets.js'; -import { L } from './range-check.js'; +import { l } from './range-check.js'; import { constraintSystem, contains, @@ -82,7 +82,7 @@ let RangeCheck = ZkProgram({ privateInputs: [Field, Field], method(xy, z) { let [x, y] = Gadgets.compactMultiRangeCheck(xy, z); - x.add(y.mul(1n << L)).assertEquals(xy); + x.add(y.mul(1n << l)).assertEquals(xy); }, }, }, @@ -104,11 +104,11 @@ await equivalentAsync({ from: [maybeUint(64)], to: boolean }, { runs: 3 })( ); await equivalentAsync( - { from: [maybeUint(L), uint(L), uint(L)], to: boolean }, + { from: [maybeUint(l), uint(l), uint(l)], to: boolean }, { runs: 3 } )( (x, y, z) => { - assert(!(x >> L) && !(y >> L) && !(z >> L), 'multi: not out of range'); + assert(!(x >> l) && !(y >> l) && !(z >> l), 'multi: not out of range'); return true; }, async (x, y, z) => { @@ -118,11 +118,11 @@ await equivalentAsync( ); await equivalentAsync( - { from: [maybeUint(2n * L), uint(L)], to: boolean }, + { from: [maybeUint(2n * l), uint(l)], to: boolean }, { runs: 3 } )( (xy, z) => { - assert(!(xy >> (2n * L)) && !(z >> L), 'compact: not out of range'); + assert(!(xy >> (2n * l)) && !(z >> l), 'compact: not out of range'); return true; }, async (xy, z) => { diff --git a/src/lib/gates.ts b/src/lib/gates.ts index d0be6abf17..5d620066c5 100644 --- a/src/lib/gates.ts +++ b/src/lib/gates.ts @@ -1,4 +1,4 @@ -import { Snarky } from '../snarky.js'; +import { KimchiGateType, Snarky } from '../snarky.js'; import { FieldConst, type Field } from './field.js'; import { MlArray, MlTuple } from './ml/base.js'; import { TupleN } from './util/types.js'; @@ -12,6 +12,7 @@ export { rotate, generic, foreignFieldAdd, + foreignFieldMul, }; const Gates = { @@ -22,6 +23,8 @@ const Gates = { rotate, generic, foreignFieldAdd, + foreignFieldMul, + raw, }; function rangeCheck0( @@ -184,3 +187,56 @@ function foreignFieldAdd({ FieldConst.fromBigint(sign) ); } + +/** + * Foreign field multiplication + */ +function foreignFieldMul(inputs: { + left: TupleN; + right: TupleN; + remainder: TupleN; + quotient: TupleN; + quotientHiBound: Field; + product1: TupleN; + carry0: Field; + carry1p: TupleN; + carry1c: TupleN; + foreignFieldModulus2: bigint; + negForeignFieldModulus: TupleN; +}) { + let { + left, + right, + remainder, + quotient, + quotientHiBound, + product1, + carry0, + carry1p, + carry1c, + foreignFieldModulus2, + negForeignFieldModulus, + } = inputs; + + Snarky.gates.foreignFieldMul( + MlTuple.mapTo(left, (x) => x.value), + MlTuple.mapTo(right, (x) => x.value), + MlTuple.mapTo(remainder, (x) => x.value), + MlTuple.mapTo(quotient, (x) => x.value), + quotientHiBound.value, + MlTuple.mapTo(product1, (x) => x.value), + carry0.value, + MlTuple.mapTo(carry1p, (x) => x.value), + MlTuple.mapTo(carry1c, (x) => x.value), + FieldConst.fromBigint(foreignFieldModulus2), + MlTuple.mapTo(negForeignFieldModulus, FieldConst.fromBigint) + ); +} + +function raw(kind: KimchiGateType, values: Field[], coefficients: bigint[]) { + Snarky.gates.raw( + kind, + MlArray.to(values.map((x) => x.value)), + MlArray.to(coefficients.map(FieldConst.fromBigint)) + ); +} diff --git a/src/lib/proof_system.ts b/src/lib/proof_system.ts index a4f24dbb03..c952b7aa6d 100644 --- a/src/lib/proof_system.ts +++ b/src/lib/proof_system.ts @@ -268,7 +268,9 @@ function ZkProgram< > ) => Promise; digest: () => string; - analyzeMethods: () => ReturnType[]; + analyzeMethods: () => { + [I in keyof Types]: ReturnType; + }; publicInputType: ProvableOrUndefined>; publicOutputType: ProvableOrVoid>; privateInputTypes: { @@ -316,9 +318,14 @@ function ZkProgram< let maxProofsVerified = getMaxProofsVerified(methodIntfs); function analyzeMethods() { - return methodIntfs.map((methodEntry, i) => - analyzeMethod(publicInputType, methodEntry, methodFunctions[i]) - ); + return Object.fromEntries( + methodIntfs.map((methodEntry, i) => [ + methodEntry.methodName, + analyzeMethod(publicInputType, methodEntry, methodFunctions[i]), + ]) + ) as any as { + [I in keyof Types]: ReturnType; + }; } let compileOutput: @@ -332,7 +339,9 @@ function ZkProgram< | undefined; async function compile({ cache = Cache.FileSystemDefault } = {}) { - let methodsMeta = analyzeMethods(); + let methodsMeta = methodIntfs.map((methodEntry, i) => + analyzeMethod(publicInputType, methodEntry, methodFunctions[i]) + ); let gates = methodsMeta.map((m) => m.gates); let { provers, verify, verificationKey } = await compileProgram({ publicInputType, diff --git a/src/lib/proof_system.unit-test.ts b/src/lib/proof_system.unit-test.ts index 477d354114..b89f1f3dd3 100644 --- a/src/lib/proof_system.unit-test.ts +++ b/src/lib/proof_system.unit-test.ts @@ -16,14 +16,12 @@ const EmptyProgram = ZkProgram({ }); const emptyMethodsMetadata = EmptyProgram.analyzeMethods(); -emptyMethodsMetadata.forEach((methodMetadata) => { - expect(methodMetadata).toEqual({ - rows: 0, - digest: '4f5ddea76d29cfcfd8c595f14e31f21b', - result: undefined, - gates: [], - publicInputSize: 0, - }); +expect(emptyMethodsMetadata.run).toEqual({ + rows: 0, + digest: '4f5ddea76d29cfcfd8c595f14e31f21b', + result: undefined, + gates: [], + publicInputSize: 0, }); class CounterPublicInput extends Struct({ @@ -47,5 +45,5 @@ const CounterProgram = ZkProgram({ }, }); -const incrementMethodMetadata = CounterProgram.analyzeMethods()[0]; +const incrementMethodMetadata = CounterProgram.analyzeMethods().increment; expect(incrementMethodMetadata).toEqual(expect.objectContaining({ rows: 18 })); diff --git a/src/snarky.d.ts b/src/snarky.d.ts index 0408c637d5..69fe3bcb37 100644 --- a/src/snarky.d.ts +++ b/src/snarky.d.ts @@ -33,6 +33,7 @@ export { Snarky, Test, JsonGate, + KimchiGateType, MlPublicKey, MlPublicKeyVar, FeatureFlags,