Skip to content

Commit

Permalink
Merge pull request #1215 from o1-labs/feature/expose-all-gates
Browse files Browse the repository at this point in the history
Expose remaining gates
  • Loading branch information
mitschabaude authored Nov 1, 2023
2 parents bde4fc0 + 7621ea6 commit 56975fc
Show file tree
Hide file tree
Showing 11 changed files with 368 additions and 143 deletions.
2 changes: 1 addition & 1 deletion src/bindings
55 changes: 28 additions & 27 deletions src/lib/gadgets/bitwise.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ import * as Gates from '../gates.js';
import {
MAX_BITS,
assert,
witnessSlices,
witnessSlice,
witnessNextValue,
divideWithRemainder,
} from './common.js';
import { rangeCheck64 } from './range-check.js';

export { xor, and, rotate };

Expand Down Expand Up @@ -66,22 +67,22 @@ function buildXor(
while (padLength !== 0) {
// slices the inputs into 4x 4bit-sized chunks
// slices of a
let in1_0 = witnessSlices(a, 0, 4);
let in1_1 = witnessSlices(a, 4, 4);
let in1_2 = witnessSlices(a, 8, 4);
let in1_3 = witnessSlices(a, 12, 4);
let in1_0 = witnessSlice(a, 0, 4);
let in1_1 = witnessSlice(a, 4, 4);
let in1_2 = witnessSlice(a, 8, 4);
let in1_3 = witnessSlice(a, 12, 4);

// slices of b
let in2_0 = witnessSlices(b, 0, 4);
let in2_1 = witnessSlices(b, 4, 4);
let in2_2 = witnessSlices(b, 8, 4);
let in2_3 = witnessSlices(b, 12, 4);
let in2_0 = witnessSlice(b, 0, 4);
let in2_1 = witnessSlice(b, 4, 4);
let in2_2 = witnessSlice(b, 8, 4);
let in2_3 = witnessSlice(b, 12, 4);

// slices of expected output
let out0 = witnessSlices(expectedOutput, 0, 4);
let out1 = witnessSlices(expectedOutput, 4, 4);
let out2 = witnessSlices(expectedOutput, 8, 4);
let out3 = witnessSlices(expectedOutput, 12, 4);
let out0 = witnessSlice(expectedOutput, 0, 4);
let out1 = witnessSlice(expectedOutput, 4, 4);
let out2 = witnessSlice(expectedOutput, 8, 4);
let out3 = witnessSlice(expectedOutput, 12, 4);

// assert that the xor of the slices is correct, 16 bit at a time
Gates.xor(
Expand Down Expand Up @@ -221,26 +222,26 @@ function rot(
rotated,
excess,
[
witnessSlices(bound, 52, 12), // bits 52-64
witnessSlices(bound, 40, 12), // bits 40-52
witnessSlices(bound, 28, 12), // bits 28-40
witnessSlices(bound, 16, 12), // bits 16-28
witnessSlice(bound, 52, 12), // bits 52-64
witnessSlice(bound, 40, 12), // bits 40-52
witnessSlice(bound, 28, 12), // bits 28-40
witnessSlice(bound, 16, 12), // bits 16-28
],
[
witnessSlices(bound, 14, 2), // bits 14-16
witnessSlices(bound, 12, 2), // bits 12-14
witnessSlices(bound, 10, 2), // bits 10-12
witnessSlices(bound, 8, 2), // bits 8-10
witnessSlices(bound, 6, 2), // bits 6-8
witnessSlices(bound, 4, 2), // bits 4-6
witnessSlices(bound, 2, 2), // bits 2-4
witnessSlices(bound, 0, 2), // bits 0-2
witnessSlice(bound, 14, 2), // bits 14-16
witnessSlice(bound, 12, 2), // bits 12-14
witnessSlice(bound, 10, 2), // bits 10-12
witnessSlice(bound, 8, 2), // bits 8-10
witnessSlice(bound, 6, 2), // bits 6-8
witnessSlice(bound, 4, 2), // bits 4-6
witnessSlice(bound, 2, 2), // bits 2-4
witnessSlice(bound, 0, 2), // bits 0-2
],
big2PowerRot
);
// Compute next row
Gates.rangeCheck64(shifted);
rangeCheck64(shifted);
// Compute following row
Gates.rangeCheck64(excess);
rangeCheck64(excess);
return [rotated, excess, shifted];
}
26 changes: 23 additions & 3 deletions src/lib/gadgets/common.ts
Original file line number Diff line number Diff line change
@@ -1,23 +1,43 @@
import { Provable } from '../provable.js';
import { Field } from '../field.js';
import { Field, FieldConst } from '../field.js';
import { TupleN } from '../util/types.js';
import { Snarky } from '../../snarky.js';
import { MlArray } from '../ml/base.js';

const MAX_BITS = 64 as const;

export {
MAX_BITS,
exists,
assert,
witnessSlices,
bitSlice,
witnessSlice,
witnessNextValue,
divideWithRemainder,
};

function exists<N extends number, C extends () => TupleN<bigint, N>>(
n: N,
compute: C
) {
let varsMl = Snarky.exists(n, () =>
MlArray.mapTo(compute(), FieldConst.fromBigint)
);
let vars = MlArray.mapFrom(varsMl, (v) => new Field(v));
return TupleN.fromArray(n, vars);
}

function assert(stmt: boolean, message?: string) {
if (!stmt) {
throw Error(message ?? 'Assertion failed');
}
}

function witnessSlices(f: Field, start: number, length: number) {
function bitSlice(x: bigint, start: number, length: number) {
return (x >> BigInt(start)) & ((1n << BigInt(length)) - 1n);
}

function witnessSlice(f: Field, start: number, length: number) {
if (length <= 0) throw Error('Length must be a positive number');

return Provable.witness(Field, () => {
Expand Down
39 changes: 36 additions & 3 deletions src/lib/gadgets/range-check.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { type Field } from '../field.js';
import { Field } from '../field.js';
import * as Gates from '../gates.js';
import { bitSlice, exists } from './common.js';

export { rangeCheck64 };

Expand All @@ -11,7 +12,39 @@ function rangeCheck64(x: Field) {
if (x.toBigInt() >= 1n << 64n) {
throw Error(`rangeCheck64: expected field to fit in 64 bits, got ${x}`);
}
} else {
Gates.rangeCheck64(x);
return;
}

// crumbs (2-bit limbs)
let [x0, x2, x4, x6, x8, x10, x12, x14] = exists(8, () => {
let xx = x.toBigInt();
return [
bitSlice(xx, 0, 2),
bitSlice(xx, 2, 2),
bitSlice(xx, 4, 2),
bitSlice(xx, 6, 2),
bitSlice(xx, 8, 2),
bitSlice(xx, 10, 2),
bitSlice(xx, 12, 2),
bitSlice(xx, 14, 2),
];
});

// 12-bit limbs
let [x16, x28, x40, x52] = exists(4, () => {
let xx = x.toBigInt();
return [
bitSlice(xx, 16, 12),
bitSlice(xx, 28, 12),
bitSlice(xx, 40, 12),
bitSlice(xx, 52, 12),
];
});

Gates.rangeCheck0(
x,
[new Field(0), new Field(0), x52, x40, x28, x16],
[x14, x12, x10, x8, x6, x4, x2, x0],
false // not using compact mode
);
}
56 changes: 13 additions & 43 deletions src/lib/gates.ts
Original file line number Diff line number Diff line change
@@ -1,45 +1,21 @@
import { Snarky } from '../snarky.js';
import { FieldVar, FieldConst, type Field } from './field.js';
import { MlArray } from './ml/base.js';
import { FieldConst, type Field } from './field.js';
import { MlArray, MlTuple } from './ml/base.js';
import { TupleN } from './util/types.js';

export { rangeCheck64, xor, zero, rotate, generic };
export { rangeCheck0, xor, zero, rotate, generic };

/**
* Asserts that x is at most 64 bits
*/
function rangeCheck64(x: Field) {
let [, x0, x2, x4, x6, x8, x10, x12, x14] = Snarky.exists(8, () => {
let xx = x.toBigInt();
// crumbs (2-bit limbs)
return [
0,
getBits(xx, 0, 2),
getBits(xx, 2, 2),
getBits(xx, 4, 2),
getBits(xx, 6, 2),
getBits(xx, 8, 2),
getBits(xx, 10, 2),
getBits(xx, 12, 2),
getBits(xx, 14, 2),
];
});
// 12-bit limbs
let [, x16, x28, x40, x52] = Snarky.exists(4, () => {
let xx = x.toBigInt();
return [
0,
getBits(xx, 16, 12),
getBits(xx, 28, 12),
getBits(xx, 40, 12),
getBits(xx, 52, 12),
];
});
function rangeCheck0(
x: Field,
xLimbs12: TupleN<Field, 6>,
xLimbs2: TupleN<Field, 8>,
isCompact: boolean
) {
Snarky.gates.rangeCheck0(
x.value,
[0, FieldVar[0], FieldVar[0], x52, x40, x28, x16],
[0, x14, x12, x10, x8, x6, x4, x2, x0],
// not using compact mode
FieldConst[0]
MlTuple.mapTo(xLimbs12, (x) => x.value),
MlTuple.mapTo(xLimbs2, (x) => x.value),
isCompact ? FieldConst[1] : FieldConst[0]
);
}

Expand Down Expand Up @@ -136,9 +112,3 @@ function generic(
function zero(a: Field, b: Field, c: Field) {
Snarky.gates.zero(a.value, b.value, c.value);
}

function getBits(x: bigint, start: number, length: number) {
return FieldConst.fromBigint(
(x >> BigInt(start)) & ((1n << BigInt(length)) - 1n)
);
}
2 changes: 1 addition & 1 deletion src/lib/group.ts
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ class Group {
return s.mul(x1.sub(x3)).sub(y1);
});

let [, x, y] = Snarky.group.ecadd(
let [, x, y] = Snarky.gates.ecAdd(
Group.from(x1.seal(), y1.seal()).#toTuple(),
Group.from(x2.seal(), y2.seal()).#toTuple(),
Group.from(x3, y3).#toTuple(),
Expand Down
55 changes: 48 additions & 7 deletions src/lib/ml/base.ts
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
import { TupleN } from '../util/types.js';

/**
* This module contains basic methods for interacting with OCaml
*/
export {
MlArray,
MlTuple,
MlPair,
MlList,
MlOption,
MlBool,
MlBytes,
MlResult,
MlUnit,
MlString,
MlTuple,
};

// ocaml types

type MlTuple<X, Y> = [0, X, Y];
type MlPair<X, Y> = [0, X, Y];
type MlArray<T> = [0, ...T[]];
type MlList<T> = [0, T, 0 | MlList<T>];
type MlOption<T> = 0 | [0, T];
Expand Down Expand Up @@ -48,18 +51,18 @@ const MlArray = {
},
};

const MlTuple = Object.assign(
function MlTuple<X, Y>(x: X, y: Y): MlTuple<X, Y> {
const MlPair = Object.assign(
function MlTuple<X, Y>(x: X, y: Y): MlPair<X, Y> {
return [0, x, y];
},
{
from<X, Y>([, x, y]: MlTuple<X, Y>): [X, Y] {
from<X, Y>([, x, y]: MlPair<X, Y>): [X, Y] {
return [x, y];
},
first<X>(t: MlTuple<X, unknown>): X {
first<X>(t: MlPair<X, unknown>): X {
return t[1];
},
second<Y>(t: MlTuple<unknown, Y>): Y {
second<Y>(t: MlPair<unknown, Y>): Y {
return t[2];
},
}
Expand Down Expand Up @@ -113,3 +116,41 @@ const MlResult = {
return [1, 0];
},
};

/**
* tuple type that has the length as generic parameter
*/
type MlTuple<T, N extends number> = N extends N
? number extends N
? [0, ...T[]] // N is not typed as a constant => fall back to array
: [0, ...TupleRec<T, N, []>]
: never;

type TupleRec<T, N extends number, R extends unknown[]> = R['length'] extends N
? R
: TupleRec<T, N, [T, ...R]>;

type Tuple<T> = [T, ...T[]] | [];

const MlTuple = {
map<T extends Tuple<any>, B>(
[, ...mlTuple]: [0, ...T],
f: (a: T[number]) => B
): [0, ...{ [i in keyof T]: B }] {
return [0, ...mlTuple.map(f)] as any;
},

mapFrom<T, N extends number, B>(
[, ...mlTuple]: MlTuple<T, N>,
f: (a: T) => B
): B[] {
return mlTuple.map(f);
},

mapTo<T extends Tuple<any> | TupleN<any, any>, B>(
tuple: T,
f: (a: T[number]) => B
): [0, ...{ [i in keyof T]: B }] {
return [0, ...tuple.map(f)] as any;
},
};
8 changes: 4 additions & 4 deletions src/lib/ml/conversion.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import { Bool, Field } from '../core.js';
import { FieldConst, FieldVar } from '../field.js';
import { Scalar, ScalarConst } from '../scalar.js';
import { PrivateKey, PublicKey } from '../signature.js';
import { MlTuple, MlBool, MlArray } from './base.js';
import { MlPair, MlBool, MlArray } from './base.js';
import { MlFieldConstArray } from './fields.js';

export { Ml, MlHashInput };
Expand All @@ -35,7 +35,7 @@ const Ml = {
type MlHashInput = [
flag: 0,
field_elements: MlArray<FieldConst>,
packed: MlArray<MlTuple<FieldConst, number>>
packed: MlArray<MlPair<FieldConst, number>>
];

const MlHashInput = {
Expand Down Expand Up @@ -86,7 +86,7 @@ function toPrivateKey(sk: ScalarConst) {
}

function fromPublicKey(pk: PublicKey): MlPublicKey {
return MlTuple(pk.x.toConstant().value[1], MlBool(pk.isOdd.toBoolean()));
return MlPair(pk.x.toConstant().value[1], MlBool(pk.isOdd.toBoolean()));
}
function toPublicKey([, x, isOdd]: MlPublicKey): PublicKey {
return PublicKey.from({
Expand All @@ -96,7 +96,7 @@ function toPublicKey([, x, isOdd]: MlPublicKey): PublicKey {
}

function fromPublicKeyVar(pk: PublicKey): MlPublicKeyVar {
return MlTuple(pk.x.value, pk.isOdd.toField().value);
return MlPair(pk.x.value, pk.isOdd.toField().value);
}
function toPublicKeyVar([, x, isOdd]: MlPublicKeyVar): PublicKey {
return PublicKey.from({ x: Field(x), isOdd: Bool(isOdd) });
Expand Down
Loading

0 comments on commit 56975fc

Please sign in to comment.