Skip to content

Commit

Permalink
Merge pull request #1672 from o1-labs/feature/offchain-state-with-ind…
Browse files Browse the repository at this point in the history
…exed-map

Use IndexedMerkleMap for OffchainState
  • Loading branch information
mitschabaude committed Jun 5, 2024
2 parents 1536a08 + 375d21f commit c573569
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 146 deletions.
6 changes: 5 additions & 1 deletion src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ export { Gadgets } from './lib/provable/gadgets/gadgets.js';
export { Types } from './bindings/mina-transaction/types.js';

export { MerkleList, MerkleListIterator } from './lib/provable/merkle-list.js';
import { IndexedMerkleMap } from './lib/provable/merkle-tree-indexed.js';
import {
IndexedMerkleMap,
IndexedMerkleMapBase,
} from './lib/provable/merkle-tree-indexed.js';
export { Option } from './lib/provable/option.js';

export * as Mina from './lib/mina/mina.js';
Expand Down Expand Up @@ -146,6 +149,7 @@ namespace Experimental {

// indexed merkle map
export let IndexedMerkleMap = Experimental_.IndexedMerkleMap;
export type IndexedMerkleMap = IndexedMerkleMapBase;

// offchain state
export let OffchainState = OffchainState_.OffchainState;
Expand Down
20 changes: 9 additions & 11 deletions src/lib/mina/actions/offchain-contract.unit-test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ import {
SmartContract,
method,
Mina,
State,
state,
PublicKey,
UInt64,
Expand All @@ -12,23 +11,22 @@ import assert from 'assert';

const proofsEnabled = true;

const { OffchainState, OffchainStateCommitments } = Experimental;
const { OffchainState } = Experimental;

const offchainState = OffchainState({
accounts: OffchainState.Map(PublicKey, UInt64),
totalSupply: OffchainState.Field(UInt64),
});
const offchainState = OffchainState(
{
accounts: OffchainState.Map(PublicKey, UInt64),
totalSupply: OffchainState.Field(UInt64),
},
{ logTotalCapacity: 10, maxActionsPerProof: 5 }
);

class StateProof extends offchainState.Proof {}

// example contract that interacts with offchain state

class ExampleContract extends SmartContract {
// TODO could have sugar for this like
// @OffchainState.commitment offchainState = OffchainState.Commitment();
@state(OffchainStateCommitments) offchainState = State(
OffchainStateCommitments.empty()
);
@state(OffchainState.Commitments) offchainState = offchainState.commitments();

@method
async createAccount(address: PublicKey, amountToMint: UInt64) {
Expand Down
166 changes: 71 additions & 95 deletions src/lib/mina/actions/offchain-state-rollup.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import { Proof, ZkProgram } from '../../proof-system/zkprogram.js';
import { Bool, Field } from '../../provable/wrapped.js';
import { Unconstrained } from '../../provable/types/unconstrained.js';
import { MerkleList, MerkleListIterator } from '../../provable/merkle-list.js';
import { Actions } from '../../../bindings/mina-transaction/transaction-leaves.js';
import { MerkleTree, MerkleWitness } from '../../provable/merkle-tree.js';
import {
IndexedMerkleMap,
IndexedMerkleMapBase,
} from '../../provable/merkle-tree-indexed.js';
import { Struct } from '../../provable/types/struct.js';
import { SelfProof } from '../../proof-system/zkprogram.js';
import { Provable } from '../../provable/provable.js';
Expand All @@ -15,7 +17,6 @@ import {
MerkleLeaf,
updateMerkleMap,
} from './offchain-state-serialization.js';
import { MerkleMap } from '../../provable/merkle-map.js';
import { getProofsEnabled } from '../mina.js';

export { OffchainStateRollup, OffchainStateCommitments };
Expand Down Expand Up @@ -43,42 +44,39 @@ class OffchainStateCommitments extends Struct({
// actionState: ActionIterator.provable,
actionState: Field,
}) {
static empty() {
let emptyMerkleRoot = new MerkleMap().getRoot();
static emptyFromHeight(height: number) {
let emptyMerkleRoot = new (IndexedMerkleMap(height))().root;
return new OffchainStateCommitments({
root: emptyMerkleRoot,
actionState: Actions.emptyActionState(),
});
}
}

const TREE_HEIGHT = 256;
class MerkleMapWitness extends MerkleWitness(TREE_HEIGHT) {}

// TODO: it would be nice to abstract the logic for proving a chain of state transition proofs

/**
* Common logic for the proof that we can go from OffchainStateCommitments A -> B
*/
function merkleUpdateBatch(
{
maxActionsPerBatch,
maxActionsPerProof,
maxActionsPerUpdate,
}: {
maxActionsPerBatch: number;
maxActionsPerProof: number;
maxActionsPerUpdate: number;
},
stateA: OffchainStateCommitments,
actions: ActionIterator,
tree: Unconstrained<MerkleTree>
tree: IndexedMerkleMapBase
): OffchainStateCommitments {
// this would be unnecessary if the iterator could just be the public input
actions.currentHash.assertEquals(stateA.actionState);

// linearize actions into a flat MerkleList, so we don't process an insane amount of dummy actions
let linearActions = LinearizedActionList.empty();

for (let i = 0; i < maxActionsPerBatch; i++) {
for (let i = 0; i < maxActionsPerProof; i++) {
let inner = actions.next().startIterating();
let isAtEnd = Bool(false);
for (let i = 0; i < maxActionsPerUpdate; i++) {
Expand All @@ -99,96 +97,66 @@ function merkleUpdateBatch(
}
actions.assertAtEnd();

// update merkle root at once for the actions of each account update
let root = stateA.root;
let intermediateRoot = root;

let intermediateUpdates: { key: Field; value: Field }[] = [];
let intermediateTree = Unconstrained.witness(() => tree.get().clone());
// tree must match the public Merkle root; the method operates on the tree internally
// TODO: this would be simpler if the tree was the public input directly
stateA.root.assertEquals(tree.root);

let intermediateTree = tree.clone();
let isValidUpdate = Bool(true);

linearActions.forEach(maxActionsPerBatch, (element, isDummy) => {
linearActions.forEach(maxActionsPerProof, (element, isDummy) => {
let { action, isCheckPoint } = element;
let { key, value, usesPreviousValue, previousValue } = action;

// merkle witness
let witness = Provable.witness(
MerkleMapWitness,
() =>
new MerkleMapWitness(intermediateTree.get().getWitness(key.toBigInt()))
);
// make sure that if this is a dummy action, we use the canonical dummy (key, value) pair
key = Provable.if(isDummy, Field(0n), key);
value = Provable.if(isDummy, Field(0n), value);

// previous value at the key
let actualPreviousValue = Provable.witness(Field, () =>
intermediateTree.get().getLeaf(key.toBigInt())
);

// prove that the witness and `actualPreviousValue` is correct, by comparing the implied root and key
// note: this just works if the (key, value) is a (0,0) dummy, because the value at the 0 key will always be 0
witness.calculateIndex().assertEquals(key, 'key mismatch');
witness
.calculateRoot(actualPreviousValue)
.assertEquals(intermediateRoot, 'root mismatch');
// set (key, value) in the intermediate tree
// note: this just works if (key, value) is a (0,0) dummy, because the value at the 0 key will always be 0
let actualPreviousValue = intermediateTree.set(key, value);

// if an expected previous value was provided, check whether it matches the actual previous value
// otherwise, the entire update in invalidated
let matchesPreviousValue = actualPreviousValue.equals(previousValue);
let matchesPreviousValue = actualPreviousValue
.orElse(0n)
.equals(previousValue);
let isValidAction = usesPreviousValue.implies(matchesPreviousValue);
isValidUpdate = isValidUpdate.and(isValidAction);

// store new value in at the key
let newRoot = witness.calculateRoot(value);

// update intermediate root if this wasn't a dummy action
intermediateRoot = Provable.if(isDummy, intermediateRoot, newRoot);
// at checkpoints, update the tree, if the entire update was valid
tree.overwriteIf(isCheckPoint.and(isValidUpdate), intermediateTree);

// at checkpoints, update the root, if the entire update was valid
root = Provable.if(isCheckPoint.and(isValidUpdate), intermediateRoot, root);
// at checkpoints, reset intermediate values
let wasValidUpdate = isValidUpdate;
isValidUpdate = Provable.if(isCheckPoint, Bool(true), isValidUpdate);
intermediateRoot = Provable.if(isCheckPoint, root, intermediateRoot);

// update the tree, outside the circuit (this should all be part of a better merkle tree API)
Provable.asProver(() => {
// ignore dummy value
if (isDummy.toBoolean()) return;

intermediateTree.get().setLeaf(key.toBigInt(), value.toConstant());
intermediateUpdates.push({ key, value });

if (isCheckPoint.toBoolean()) {
// if the update was valid, apply the intermediate updates to the actual tree
if (wasValidUpdate.toBoolean()) {
intermediateUpdates.forEach(({ key, value }) => {
tree.get().setLeaf(key.toBigInt(), value.toConstant());
});
}
// otherwise, we have to roll back the intermediate tree (TODO: inefficient)
else {
intermediateTree.set(tree.get().clone());
}
intermediateUpdates = [];
}
});
intermediateTree.overwriteIf(isCheckPoint, tree);
});

return { root, actionState: actions.currentHash };
return { root: tree.root, actionState: actions.currentHash };
}

/**
* This program represents a proof that we can go from OffchainStateCommitments A -> B
*/
function OffchainStateRollup({
// 1 action uses about 7.5k constraints
// we can fit at most 7 * 7.5k = 52.5k constraints in one method next to proof verification
// => we use `maxActionsPerBatch = 6` to safely stay below the constraint limit
// the second parameter `maxActionsPerUpdate` only weakly affects # constraints, but has to be <= `maxActionsPerBatch`
// => so we set it to the same value
maxActionsPerBatch = 6,
maxActionsPerUpdate = 6,
/**
* the constraints used in one batch proof with a height-31 tree are:
*
* 1967*A + 87*A*U + 2
*
* where A = maxActionsPerProof and U = maxActionsPerUpdate.
*
* To determine defaults, we set U=4 which should cover most use cases while ensuring
* that the main loop which is independent of U dominates.
*
* Targeting ~50k constraints, to leave room for recursive verification, yields A=22.
*/
maxActionsPerProof = 22,
maxActionsPerUpdate = 4,
logTotalCapacity = 30,
} = {}) {
class IndexedMerkleMapN extends IndexedMerkleMap(logTotalCapacity + 1) {}

let offchainStateRollup = ZkProgram({
name: 'merkle-map-rollup',
publicInput: OffchainStateCommitments,
Expand All @@ -199,15 +167,15 @@ function OffchainStateRollup({
*/
firstBatch: {
// [actions, tree]
privateInputs: [ActionIterator.provable, Unconstrained.provable],
privateInputs: [ActionIterator.provable, IndexedMerkleMapN.provable],

async method(
stateA: OffchainStateCommitments,
actions: ActionIterator,
tree: Unconstrained<MerkleTree>
tree: IndexedMerkleMapN
): Promise<OffchainStateCommitments> {
return merkleUpdateBatch(
{ maxActionsPerBatch, maxActionsPerUpdate },
{ maxActionsPerProof, maxActionsPerUpdate },
stateA,
actions,
tree
Expand All @@ -221,14 +189,14 @@ function OffchainStateRollup({
// [actions, tree, proof]
privateInputs: [
ActionIterator.provable,
Unconstrained.provable,
IndexedMerkleMapN.provable,
SelfProof,
],

async method(
stateA: OffchainStateCommitments,
actions: ActionIterator,
tree: Unconstrained<MerkleTree>,
tree: IndexedMerkleMapN,
recursiveProof: Proof<
OffchainStateCommitments,
OffchainStateCommitments
Expand All @@ -247,7 +215,7 @@ function OffchainStateRollup({
let stateB = recursiveProof.publicOutput;

return merkleUpdateBatch(
{ maxActionsPerBatch, maxActionsPerUpdate },
{ maxActionsPerProof, maxActionsPerUpdate },
stateB,
actions,
tree
Expand All @@ -272,16 +240,19 @@ function OffchainStateRollup({
return result;
},

async prove(tree: MerkleTree, actions: MerkleList<MerkleList<MerkleLeaf>>) {
assert(tree.height === TREE_HEIGHT, 'Tree height must match');
async prove(
tree: IndexedMerkleMapN,
actions: MerkleList<MerkleList<MerkleLeaf>>
) {
assert(tree.height === logTotalCapacity + 1, 'Tree height must match');
if (getProofsEnabled()) await this.compile();
// clone the tree so we don't modify the input
tree = tree.clone();

// input state
let iterator = actions.startIterating();
let inputState = new OffchainStateCommitments({
root: tree.getRoot(),
root: tree.root,
actionState: iterator.currentHash,
});

Expand All @@ -304,34 +275,39 @@ function OffchainStateRollup({
updateMerkleMap(actionsList, tree);

let finalState = new OffchainStateCommitments({
root: tree.getRoot(),
root: tree.root,
actionState: iterator.hash,
});
let proof = await RollupProof.dummy(inputState, finalState, 2, 15);
return { proof, tree, nProofs: 0 };
}

// base proof
let slice = sliceActions(iterator, maxActionsPerBatch);
let proof = await offchainStateRollup.firstBatch(
inputState,
slice,
Unconstrained.from(tree)
);
let slice = sliceActions(iterator, maxActionsPerProof);
let proof = await offchainStateRollup.firstBatch(inputState, slice, tree);

// update tree root/length again, they aren't mutated :(
// TODO: this shows why the full tree should be the public output
tree.root = proof.publicOutput.root;
tree.length = Field(tree.data.get().sortedLeaves.length);

// recursive proofs
let nProofs = 1;
for (let i = 1; ; i++) {
if (iterator.isAtEnd().toBoolean()) break;
nProofs++;

let slice = sliceActions(iterator, maxActionsPerBatch);
let slice = sliceActions(iterator, maxActionsPerProof);
proof = await offchainStateRollup.nextBatch(
inputState,
slice,
Unconstrained.from(tree),
tree,
proof
);

// update tree root/length again, they aren't mutated :(
tree.root = proof.publicOutput.root;
tree.length = Field(tree.data.get().sortedLeaves.length);
}

return { proof, tree, nProofs };
Expand Down
Loading

0 comments on commit c573569

Please sign in to comment.