diff --git a/CHANGELOG.md b/CHANGELOG.md index b0cd869cc7..400815b4c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,8 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm ### Added +-`ZkProgram` to support non-pure provable types as inputs and outputs https://github.com/o1-labs/o1js/pull/1828 + - Support secp256r1 in elliptic curve and ECDSA gadgets https://github.com/o1-labs/o1js/pull/1885 ### Fixed diff --git a/src/bindings b/src/bindings index e081466d5e..acc5a7c566 160000 --- a/src/bindings +++ b/src/bindings @@ -1 +1 @@ -Subproject commit e081466d5e435fe3fbd32fa83ef5c0ef3f030dec +Subproject commit acc5a7c56645ea46f58e698bca75a011407d45b0 diff --git a/src/examples/zkprogram/program-with-non-pure-input.ts b/src/examples/zkprogram/program-with-non-pure-input.ts new file mode 100644 index 0000000000..3566584f06 --- /dev/null +++ b/src/examples/zkprogram/program-with-non-pure-input.ts @@ -0,0 +1,39 @@ +import { Field, Provable, Struct, ZkProgram, assert } from 'o1js'; + +class MyStruct extends Struct({ + label: String, + value: Field, +}) {} + +let MyProgram = ZkProgram({ + name: 'example-with-non-pure-inputs', + publicInput: MyStruct, + publicOutput: MyStruct, + + methods: { + baseCase: { + privateInputs: [], + async method(input: MyStruct) { + //update input in circuit + input.label = 'in-circuit'; + return { + publicOutput: input, + }; + }, + }, + }, +}); + +// + +console.log('compiling MyProgram...'); +await MyProgram.compile(); +console.log('compile done'); + +let input = new MyStruct({ label: 'input', value: Field(5) }); + +let { proof } = await MyProgram.baseCase(input); +let ok = await MyProgram.verify(proof); + +assert(ok, 'proof not valid!'); +assert(proof.publicOutput.label === 'in-circuit'); diff --git a/src/lib/proof-system/proof.ts b/src/lib/proof-system/proof.ts index 9111482cf9..f19402fc8d 100644 --- a/src/lib/proof-system/proof.ts +++ b/src/lib/proof-system/proof.ts @@ -2,7 +2,7 @@ import { initializeBindings, withThreadPool } from '../../snarky.js'; import { Pickles } from '../../snarky.js'; import { Field, Bool } from '../provable/wrapped.js'; import type { - FlexibleProvablePure, + FlexibleProvable, InferProvable, } from '../provable/types/struct.js'; import { FeatureFlags } from './feature-flags.js'; @@ -22,8 +22,8 @@ export { dummyProof, extractProofs, extractProofTypes, type ProofValue }; type MaxProofs = 0 | 1 | 2; class ProofBase { - static publicInputType: FlexibleProvablePure = undefined as any; - static publicOutputType: FlexibleProvablePure = undefined as any; + static publicInputType: FlexibleProvable = undefined as any; + static publicOutputType: FlexibleProvable = undefined as any; static tag: () => { name: string } = () => { throw Error( `You cannot use the \`Proof\` class directly. Instead, define a subclass:\n` + diff --git a/src/lib/proof-system/zkprogram.ts b/src/lib/proof-system/zkprogram.ts index 26b06039eb..9efb896edb 100644 --- a/src/lib/proof-system/zkprogram.ts +++ b/src/lib/proof-system/zkprogram.ts @@ -3,7 +3,7 @@ import { Snarky, initializeBindings, withThreadPool } from '../../snarky.js'; import { Pickles, Gate } from '../../snarky.js'; import { Field } from '../provable/wrapped.js'; import { - FlexibleProvablePure, + FlexibleProvable, InferProvable, ProvablePureExtended, Struct, @@ -89,19 +89,28 @@ const Void: ProvablePureExtended = EmptyVoid(); function createProgramState() { let methodCache: Map = new Map(); - return { + setNonPureOutput(value: any[]) { + methodCache.set('__nonPureOutput__', value); + }, + getNonPureOutput(): any[] { + let entry = methodCache.get('__nonPureOutput__'); + if (entry === undefined) throw Error(`Non-pure output not defined`); + return entry as any[]; + }, + setAuxiliaryOutput(value: unknown, methodName: string) { methodCache.set(methodName, value); }, + getAuxiliaryOutput(methodName: string): unknown { let entry = methodCache.get(methodName); if (entry === undefined) throw Error(`Auxiliary value for method ${methodName} not defined`); return entry; }, - reset(methodName: string) { - methodCache.delete(methodName); + reset(key: string) { + methodCache.delete(key); }, }; } @@ -173,8 +182,8 @@ let SideloadedTag = { function ZkProgram< Config extends { - publicInput?: ProvableTypePure; - publicOutput?: ProvableTypePure; + publicInput?: ProvableType; + publicOutput?: ProvableType; methods: { [I in string]: { privateInputs: Tuple; @@ -250,10 +259,10 @@ function ZkProgram< let doProving = true; let methods = config.methods; - let publicInputType: ProvablePure = ProvableType.get( + let publicInputType: Provable = ProvableType.get( config.publicInput ?? Undefined ); - let publicOutputType: ProvablePure = ProvableType.get( + let publicOutputType: Provable = ProvableType.get( config.publicOutput ?? Void ); @@ -391,10 +400,30 @@ function ZkProgram< `Try calling \`await program.compile()\` first, this will cache provers in the background.\nIf you compiled your zkProgram with proofs disabled (\`proofsEnabled = false\`), you have to compile it with proofs enabled first.` ); } - let publicInputFields = toFieldConsts(publicInputType, publicInput); + + let nonPureInputExists = + publicInputType.toAuxiliary(publicInput).length !== 0; + console.log('nonPure Input Exists', nonPureInputExists); + let publicInputFields, publicInputAux; + if (nonPureInputExists) { + // serialize publicInput into pure provable field elements and auxilary data + ({ publicInputFields, publicInputAux } = toFieldAndAuxConsts( + publicInputType, + publicInput + )); + } else { + publicInputFields = toFieldConsts(publicInputType, publicInput); + } + let previousProofs = MlArray.to(getPreviousProofsForProver(args)); - let id = snarkContext.enter({ witnesses: args, inProver: true }); + console.log('auxdata before entering snarkContext ', publicInputAux); + let id = snarkContext.enter({ + witnesses: args, + inProver: true, + auxInputData: publicInputAux, + }); + let result: UnwrapPromise>; try { result = await picklesProver(publicInputFields, previousProofs); @@ -415,8 +444,21 @@ function ZkProgram< programState.reset(methodIntfs[i].methodName); } + let publicOutput; let [publicOutputFields, proof] = MlPair.from(result); - let publicOutput = fromFieldConsts(publicOutputType, publicOutputFields); + if (nonPureInputExists) { + let nonPureOutput = programState.getNonPureOutput(); + + publicOutput = fromFieldConsts( + publicOutputType, + publicOutputFields, + nonPureOutput + ); + + programState.reset('nonPureOutput'); + } else { + publicOutput = fromFieldConsts(publicOutputType, publicOutputFields); + } return { proof: new ProgramProof({ @@ -649,8 +691,8 @@ async function compileProgram({ overrideWrapDomain, state, }: { - publicInputType: ProvablePure; - publicOutputType: ProvablePure; + publicInputType: Provable; + publicOutputType: Provable; methodIntfs: MethodInterface[]; methods: ((...args: any) => unknown)[]; gates: Gate[][]; @@ -762,7 +804,7 @@ If you are using a SmartContract, make sure you are using the @method decorator. } function analyzeMethod( - publicInputType: ProvablePure, + publicInputType: Provable, methodIntf: MethodInterface, method: (...args: any) => unknown ) { @@ -790,8 +832,8 @@ function inCircuitVkHash(inCircuitVk: unknown): Field { } function picklesRuleFromFunction( - publicInputType: ProvablePure, - publicOutputType: ProvablePure, + publicInputType: Provable, + publicOutputType: Provable, func: (...args: unknown[]) => unknown, proofSystemTag: { name: string }, { methodName, args, auxiliaryType }: MethodInterface, @@ -801,7 +843,11 @@ function picklesRuleFromFunction( async function main( publicInput: MlFieldArray ): ReturnType { - let { witnesses: argsWithoutPublicInput, inProver } = snarkContext.get(); + let { + witnesses: argsWithoutPublicInput, + inProver, + auxInputData, + } = snarkContext.get(); assert(!(inProver && argsWithoutPublicInput === undefined)); let finalArgs = []; let proofs: { @@ -837,10 +883,22 @@ function picklesRuleFromFunction( if (publicInputType === Undefined || publicInputType === Void) { result = (await func(...finalArgs)) as any; } else { - let input = fromFieldVars(publicInputType, publicInput); + console.log('auxData before input', auxInputData); + let input = fromFieldVars(publicInputType, publicInput, auxInputData); result = (await func(input, ...finalArgs)) as any; } + console.log('result input', result); + if (result?.publicOutput) { + // store the nonPure auxiliary data in program state cache if it exists + let nonPureOutput = publicOutputType.toAuxiliary(result.publicOutput); + let nonPureOutputExists = nonPureOutput.length !== 0; + + if (state !== undefined && nonPureOutputExists) { + state.setNonPureOutput(nonPureOutput); + } + } + proofs.forEach(({ Proof, proof }) => { if (!(proof instanceof DynamicProof)) return; @@ -869,7 +927,7 @@ function picklesRuleFromFunction( Pickles.sideLoaded.inCircuit(computedTag, circuitVk); }); - // if the public output is empty, we don't evaluate `toFields(result)` to allow the function to return something else in that case + // if the output is empty, we don't evaluate `toFields(result)` to allow the function to return something else in that case let hasPublicOutput = publicOutputType.sizeInFields() !== 0; let publicOutput = hasPublicOutput ? publicOutputType.toFields(result.publicOutput) @@ -957,20 +1015,40 @@ function getMaxProofsVerified(methodIntfs: MethodInterface[]) { ) as any as 0 | 1 | 2; } -function fromFieldVars(type: ProvablePure, fields: MlFieldArray) { - return type.fromFields(MlFieldArray.from(fields)); +function fromFieldVars( + type: Provable, + fields: MlFieldArray, + auxData: any[] = [] +) { + return type.fromFields(MlFieldArray.from(fields), auxData); } -function fromFieldConsts(type: ProvablePure, fields: MlFieldConstArray) { - return type.fromFields(MlFieldConstArray.from(fields)); +function toFieldVars(type: ProvablePure, value: T) { + return MlFieldArray.to(type.toFields(value)); } -function toFieldConsts(type: ProvablePure, value: T) { + +function fromFieldConsts( + type: Provable, + fields: MlFieldConstArray, + aux: any[] = [] +) { + return type.fromFields(MlFieldConstArray.from(fields), aux); +} + +function toFieldConsts(type: Provable, value: T) { return MlFieldConstArray.to(type.toFields(value)); } +function toFieldAndAuxConsts(type: Provable, value: T) { + return { + publicInputFields: MlFieldConstArray.to(type.toFields(value)), + publicInputAux: type.toAuxiliary(value), + }; +} + ZkProgram.Proof = function < - PublicInputType extends FlexibleProvablePure, - PublicOutputType extends FlexibleProvablePure + PublicInputType extends FlexibleProvable, + PublicOutputType extends FlexibleProvable >(program: { name: string; publicInputType: PublicInputType; diff --git a/src/lib/provable/core/provable-context.ts b/src/lib/provable/core/provable-context.ts index a697c5d00d..ee44ab9c5e 100644 --- a/src/lib/provable/core/provable-context.ts +++ b/src/lib/provable/core/provable-context.ts @@ -40,6 +40,7 @@ type SnarkContext = { inCheckedComputation?: boolean; inAnalyze?: boolean; inWitnessBlock?: boolean; + auxInputData?: any[]; }; let snarkContext = Context.create({ default: {} }); diff --git a/src/mina b/src/mina index 6899054b74..24c8b2d723 160000 --- a/src/mina +++ b/src/mina @@ -1 +1 @@ -Subproject commit 6899054b745c1323b9d5bcaa62c00bed2ad1ead3 +Subproject commit 24c8b2d723fb09d0d7f996b6ac35373dc27084ef