From 0a0f4ed2fe1c8606f7f748ce9daf2a8f400dd5be Mon Sep 17 00:00:00 2001 From: Kyryl Riabov Date: Fri, 4 Oct 2024 16:04:01 +0300 Subject: [PATCH] Added support for multiple protocols --- src/core/CircuitTypesGenerator.ts | 30 ++++-- src/core/ZkitTSGenerator.ts | 93 +++++++++++++------ src/types/circuitArtifact.ts | 2 +- src/types/typesGenerator.ts | 6 ++ test/CircuitProofGeneration.test.ts | 20 ++-- test/CircuitTypesGenerator.test.ts | 12 ++- ...AtomicQueryMTPOnChainVoting_artifacts.json | 2 +- test/fixture-cache/Multiplier2_artifacts.json | 2 +- .../auth/EnhancedMultiplier_artifacts.json | 2 +- test/fixture-cache/auth/Matrix_artifacts.json | 2 +- .../auth/Multiplier2_artifacts.json | 2 +- .../lib/Multiplier2_artifacts.json | 2 +- 12 files changed, 120 insertions(+), 55 deletions(-) diff --git a/src/core/CircuitTypesGenerator.ts b/src/core/CircuitTypesGenerator.ts index d7d5d64..2c15782 100644 --- a/src/core/CircuitTypesGenerator.ts +++ b/src/core/CircuitTypesGenerator.ts @@ -7,7 +7,7 @@ import ZkitTSGenerator from "./ZkitTSGenerator"; import { normalizeName } from "../utils"; import { Formats } from "../constants"; -import { CircuitArtifact, ArtifactWithPath } from "../types"; +import { CircuitArtifact, ArtifactWithPath, GeneratedCircuitWrapperResult } from "../types"; /** * `CircuitTypesGenerator` is need for generating TypeScript bindings based on circuit artifacts. @@ -88,14 +88,26 @@ export class CircuitTypesGenerator extends ZkitTSGenerator { }); const pathToGeneratedFile = path.join(this.getOutputTypesDir(), circuitTypePath); - const preparedNode = await this._returnTSDefinitionByArtifact(circuitArtifacts[i], pathToGeneratedFile); + const preparedNodes: GeneratedCircuitWrapperResult[] = await this._returnTSDefinitionByArtifact( + circuitArtifacts[i], + pathToGeneratedFile, + ); - this._saveFileContent(circuitTypePath, preparedNode); + for (const preparedNode of preparedNodes) { + circuitTypePath = path.join(path.dirname(circuitTypePath), preparedNode.className + ".ts"); - typePathsToResolve.push({ - circuitArtifact: circuitArtifacts[i], - pathToGeneratedFile: path.join(this.getOutputTypesDir(), circuitTypePath), - }); + this._saveFileContent(circuitTypePath, preparedNode.content); + + typePathsToResolve.push({ + circuitArtifact: { + ...circuitArtifacts[i], + circuitTemplateName: + circuitArtifacts[i].circuitTemplateName + + (circuitArtifacts[i].baseCircuitInfo.protocol.length > 1 ? preparedNode.prefix : ""), + }, + pathToGeneratedFile: path.join(this.getOutputTypesDir(), circuitTypePath), + }); + } } await this._resolveTypePaths(typePathsToResolve); @@ -253,10 +265,10 @@ export class CircuitTypesGenerator extends ZkitTSGenerator { private async _returnTSDefinitionByArtifact( circuitArtifact: CircuitArtifact, pathToGeneratedFile: string, - ): Promise { + ): Promise { switch (circuitArtifact._format) { case Formats.V1HH_ZKIT_TYPE: - return await this._genCircuitWrapperClassContent(circuitArtifact, pathToGeneratedFile); + return await this._genCircuitWrappersClassContent(circuitArtifact, pathToGeneratedFile); default: throw new Error(`Unsupported format: ${circuitArtifact._format}`); } diff --git a/src/core/ZkitTSGenerator.ts b/src/core/ZkitTSGenerator.ts index 47444d3..b3705fa 100644 --- a/src/core/ZkitTSGenerator.ts +++ b/src/core/ZkitTSGenerator.ts @@ -15,6 +15,7 @@ import { DefaultWrapperTemplateParams, WrapperTemplateParams, SignalInfo, + GeneratedCircuitWrapperResult, } from "../types"; import { normalizeName } from "../utils"; @@ -67,12 +68,45 @@ export default class ZkitTSGenerator extends BaseTSGenerator { .join("."); } - protected async _genCircuitWrapperClassContent( + protected async _genCircuitWrappersClassContent( circuitArtifact: CircuitArtifact, pathToGeneratedFile: string, - ): Promise { + ): Promise { this._validateCircuitArtifact(circuitArtifact); + const result: GeneratedCircuitWrapperResult[] = []; + + const unifiedProtocolType = new Set(circuitArtifact.baseCircuitInfo.protocol); + for (const protocolType of unifiedProtocolType) { + const content = await this._genSingleCircuitWrapperClassContent( + circuitArtifact, + pathToGeneratedFile, + protocolType, + unifiedProtocolType.size > 1, + ); + + result.push(content); + } + + return result; + } + + protected async _genDefaultCircuitWrapperClassContent(circuitArtifact: CircuitArtifact): Promise { + const template = fs.readFileSync(path.join(__dirname, "templates", "default-circuit-wrapper.ts.ejs"), "utf8"); + + const templateParams: DefaultWrapperTemplateParams = { + circuitClassName: this._getCircuitName(circuitArtifact), + }; + + return await prettier.format(ejs.render(template, templateParams), { parser: "typescript" }); + } + + private async _genSingleCircuitWrapperClassContent( + circuitArtifact: CircuitArtifact, + pathToGeneratedFile: string, + protocolType: "groth16" | "plonk", + isPrefixed: boolean = false, + ): Promise { const template = fs.readFileSync(path.join(__dirname, "templates", "circuit-wrapper.ts.ejs"), "utf8"); let outputCounter: number = 0; @@ -116,32 +150,28 @@ export default class ZkitTSGenerator extends BaseTSGenerator { } const pathToUtils = path.join(this.getOutputTypesDir(), "utils"); + const circuitClassName = this._getCircuitName(circuitArtifact) + (isPrefixed ? this._getPrefix(protocolType) : ""); + const templateParams: WrapperTemplateParams = { - protocolTypeName: circuitArtifact.baseCircuitInfo.protocol, - protocolImplementerName: this._getProtocolImplementerName(circuitArtifact), - proofTypeInternalName: this._getProofTypeInternalName(circuitArtifact), - circuitClassName: this._getCircuitName(circuitArtifact), + protocolTypeName: protocolType, + protocolImplementerName: this._getProtocolImplementerName(protocolType), + proofTypeInternalName: this._getProofTypeInternalName(protocolType), + circuitClassName, publicInputsTypeName: this._getTypeName(circuitArtifact, "Public"), calldataPubSignalsType: this._getCalldataPubSignalsType(calldataPubSignalsCount), publicInputs, privateInputs, - calldataPointsType: this._getCalldataPointsType(circuitArtifact), + calldataPointsType: this._getCalldataPointsType(protocolType), proofTypeName: this._getTypeName(circuitArtifact, "Proof"), privateInputsTypeName: this._getTypeName(circuitArtifact, "Private"), pathToUtils: path.relative(path.dirname(pathToGeneratedFile), pathToUtils), }; - return await prettier.format(ejs.render(template, templateParams), { parser: "typescript" }); - } - - protected async _genDefaultCircuitWrapperClassContent(circuitArtifact: CircuitArtifact): Promise { - const template = fs.readFileSync(path.join(__dirname, "templates", "default-circuit-wrapper.ts.ejs"), "utf8"); - - const templateParams: DefaultWrapperTemplateParams = { - circuitClassName: this._getCircuitName(circuitArtifact), + return { + content: await prettier.format(ejs.render(template, templateParams), { parser: "typescript" }), + className: circuitClassName, + prefix: this._getPrefix(protocolType), }; - - return await prettier.format(ejs.render(template, templateParams), { parser: "typescript" }); } private _getCalldataPubSignalsType(pubSignalsCount: number): string { @@ -158,36 +188,47 @@ export default class ZkitTSGenerator extends BaseTSGenerator { return signal.dimension.reduce((acc: number, dim: string) => acc * Number(dim), 1); } - private _getProtocolImplementerName(circuitArtifact: CircuitArtifact): any { - switch (circuitArtifact.baseCircuitInfo.protocol) { + private _getProtocolImplementerName(protocolType: string): any { + switch (protocolType) { case "groth16": return "Groth16Implementer"; case "plonk": return "PlonkImplementer"; default: - throw new Error(`Unknown protocol: ${circuitArtifact.baseCircuitInfo.protocol}`); + throw new Error(`Unknown protocol: ${protocolType}`); } } - private _getProofTypeInternalName(circuitArtifact: CircuitArtifact): any { - switch (circuitArtifact.baseCircuitInfo.protocol) { + private _getProofTypeInternalName(protocolType: string): any { + switch (protocolType) { case "groth16": return "Groth16Proof"; case "plonk": return "PlonkProof"; default: - throw new Error(`Unknown protocol: ${circuitArtifact.baseCircuitInfo.protocol}`); + throw new Error(`Unknown protocol: ${protocolType}`); } } - private _getCalldataPointsType(circuitArtifact: CircuitArtifact): any { - switch (circuitArtifact.baseCircuitInfo.protocol) { + private _getCalldataPointsType(protocolType: string): any { + switch (protocolType) { case "groth16": return Groth16CalldataPointsType; case "plonk": return PlonkCalldataPointsType; default: - throw new Error(`Unknown protocol: ${circuitArtifact.baseCircuitInfo.protocol}`); + throw new Error(`Unknown protocol: ${protocolType}`); + } + } + + private _getPrefix(protocolType: string): string { + switch (protocolType) { + case "groth16": + return "Groth16"; + case "plonk": + return "Plonk"; + default: + throw new Error(`Unknown protocol: ${protocolType}`); } } diff --git a/src/types/circuitArtifact.ts b/src/types/circuitArtifact.ts index 977afec..cc880ad 100644 --- a/src/types/circuitArtifact.ts +++ b/src/types/circuitArtifact.ts @@ -30,7 +30,7 @@ export type CircuitArtifact = { * @param {SignalInfo[]} signals - The array of `input` and `output` signals used in the circuit. */ export type BaseCircuitInfo = { - protocol: "groth16" | "plonk"; + protocol: ["groth16" | "plonk"]; constraintsNumber: number; signals: SignalInfo[]; }; diff --git a/src/types/typesGenerator.ts b/src/types/typesGenerator.ts index 965fc13..8924626 100644 --- a/src/types/typesGenerator.ts +++ b/src/types/typesGenerator.ts @@ -39,3 +39,9 @@ export interface CircuitClass { export interface TypeExtensionTemplateParams { circuitClasses: CircuitClass[]; } + +export interface GeneratedCircuitWrapperResult { + content: string; + className: string; + prefix: string; +} diff --git a/test/CircuitProofGeneration.test.ts b/test/CircuitProofGeneration.test.ts index abd31d2..65fdad9 100644 --- a/test/CircuitProofGeneration.test.ts +++ b/test/CircuitProofGeneration.test.ts @@ -36,7 +36,7 @@ describe("Circuit Proof Generation", function () { }); it("should generate and verify proof for Basic.circom", async () => { - const object = await circuitTypesGenerator.getCircuitObject("circuits/fixture/Basic.circom:Multiplier2"); + const object = await circuitTypesGenerator.getCircuitObject("circuits/fixture/Basic.circom:Multiplier2Groth16"); const circuit = new object(basicConfig); @@ -73,16 +73,18 @@ describe("Circuit Proof Generation", function () { }); it("should correctly import all of the zktype objects", async () => { - new (await circuitTypesGenerator.getCircuitObject("circuits/fixture/Basic.circom:Multiplier2"))(); - new (await circuitTypesGenerator.getCircuitObject("circuits/fixture/lib/BasicInLib.circom:Multiplier2"))(); - new (await circuitTypesGenerator.getCircuitObject("circuits/fixture/auth/BasicInAuth.circom:Multiplier2"))(); - new (await circuitTypesGenerator.getCircuitObject("CredentialAtomicQueryMTPOnChainVoting"))(); new (await circuitTypesGenerator.getCircuitObject("EnhancedMultiplier"))(); new (await circuitTypesGenerator.getCircuitObject("Matrix"))(); - - await expect(circuitTypesGenerator.getCircuitObject("Multiplier3")).to.be.rejectedWith( - "Circuit Multiplier3 type does not exist.", - ); + new (await circuitTypesGenerator.getCircuitObject("Multiplier2"))(); + new (await circuitTypesGenerator.getCircuitObject("circuits/fixture/lib/BasicInLib.circom:Multiplier2Groth16"))(); + new (await circuitTypesGenerator.getCircuitObject("circuits/fixture/Basic.circom:Multiplier2Groth16"))(); + new (await circuitTypesGenerator.getCircuitObject("circuits/fixture/lib/BasicInLib.circom:Multiplier2Plonk"))(); + new (await circuitTypesGenerator.getCircuitObject("CredentialAtomicQueryMTPOnChainVotingGroth16"))(); + new (await circuitTypesGenerator.getCircuitObject("circuits/fixture/Basic.circom:Multiplier2Plonk"))(); + + await expect( + circuitTypesGenerator.getCircuitObject("circuits/fixture/lib/Basic.circom:Multiplier2Groth16"), + ).to.be.rejectedWith("Circuit Multiplier2Groth16 type does not exist."); await expect(circuitTypesGenerator.getCircuitObject("test/fixture/Basic.circom:Multiplier3")).to.be.rejectedWith( "Circuit Multiplier3 type does not exist.", ); diff --git a/test/CircuitTypesGenerator.test.ts b/test/CircuitTypesGenerator.test.ts index 6aa3905..2abc484 100644 --- a/test/CircuitTypesGenerator.test.ts +++ b/test/CircuitTypesGenerator.test.ts @@ -9,13 +9,17 @@ import { CircuitTypesGenerator } from "../src"; describe("Circuit Types Generation", function () { const expectedTypes = [ - "core/lib/BasicInLib.circom/Multiplier2.ts", + "core/lib/BasicInLib.circom/Multiplier2Groth16.ts", + "core/lib/BasicInLib.circom/Multiplier2Plonk.ts", "core/auth/EnhancedMultiplier.ts", "core/auth/Matrix.ts", "core/auth/Multiplier2.ts", - "core/Basic.circom/Multiplier2.ts", - "core/lib/BasicInLib.circom/Multiplier2.ts", - "core/CredentialAtomicQueryMTPOnChainVoting.ts", + "core/Basic.circom/Multiplier2Groth16.ts", + "core/Basic.circom/Multiplier2Plonk.ts", + "core/lib/BasicInLib.circom/Multiplier2Groth16.ts", + "core/lib/BasicInLib.circom/Multiplier2Plonk.ts", + "core/CredentialAtomicQueryMTPOnChainVotingPlonk.ts", + "core/CredentialAtomicQueryMTPOnChainVotingGroth16.ts", ]; const circuitTypesGenerator = new CircuitTypesGenerator({ diff --git a/test/fixture-cache/CredentialAtomicQueryMTPOnChainVoting_artifacts.json b/test/fixture-cache/CredentialAtomicQueryMTPOnChainVoting_artifacts.json index 5fef3d8..5afb57d 100644 --- a/test/fixture-cache/CredentialAtomicQueryMTPOnChainVoting_artifacts.json +++ b/test/fixture-cache/CredentialAtomicQueryMTPOnChainVoting_artifacts.json @@ -5,7 +5,7 @@ "circuitSourceName": "circuits/fixture/credentialAtomicQueryMTPV2OnChainVoting.circom", "baseCircuitInfo": { "constraintsNumber": 86791, - "protocol": "groth16", + "protocol": ["groth16", "plonk"], "signals": [ { "name": "merklized", diff --git a/test/fixture-cache/Multiplier2_artifacts.json b/test/fixture-cache/Multiplier2_artifacts.json index efa2caa..5913629 100644 --- a/test/fixture-cache/Multiplier2_artifacts.json +++ b/test/fixture-cache/Multiplier2_artifacts.json @@ -5,7 +5,7 @@ "circuitSourceName": "circuits/fixture/Basic.circom", "baseCircuitInfo": { "constraintsNumber": 1, - "protocol": "groth16", + "protocol": ["groth16", "plonk"], "signals": [ { "name": "in1", diff --git a/test/fixture-cache/auth/EnhancedMultiplier_artifacts.json b/test/fixture-cache/auth/EnhancedMultiplier_artifacts.json index 9428b41..739c12d 100644 --- a/test/fixture-cache/auth/EnhancedMultiplier_artifacts.json +++ b/test/fixture-cache/auth/EnhancedMultiplier_artifacts.json @@ -5,7 +5,7 @@ "circuitSourceName": "circuits/fixture/auth/EMultiplier.circom", "baseCircuitInfo": { "constraintsNumber": 1, - "protocol": "groth16", + "protocol": ["groth16"], "signals": [ { "name": "in1", diff --git a/test/fixture-cache/auth/Matrix_artifacts.json b/test/fixture-cache/auth/Matrix_artifacts.json index 6b6dfb0..145a49b 100644 --- a/test/fixture-cache/auth/Matrix_artifacts.json +++ b/test/fixture-cache/auth/Matrix_artifacts.json @@ -5,7 +5,7 @@ "circuitSourceName": "circuits/fixture/auth/Matrix.circom", "baseCircuitInfo": { "constraintsNumber": 8, - "protocol": "groth16", + "protocol": ["groth16"], "signals": [ { "name": "a", diff --git a/test/fixture-cache/auth/Multiplier2_artifacts.json b/test/fixture-cache/auth/Multiplier2_artifacts.json index f66efde..79f8a6a 100644 --- a/test/fixture-cache/auth/Multiplier2_artifacts.json +++ b/test/fixture-cache/auth/Multiplier2_artifacts.json @@ -5,7 +5,7 @@ "circuitSourceName": "circuits/fixture/auth/BasicInAuth.circom", "baseCircuitInfo": { "constraintsNumber": 1, - "protocol": "groth16", + "protocol": ["plonk"], "signals": [ { "name": "in1", diff --git a/test/fixture-cache/lib/Multiplier2_artifacts.json b/test/fixture-cache/lib/Multiplier2_artifacts.json index 4a3ad5f..32c1888 100644 --- a/test/fixture-cache/lib/Multiplier2_artifacts.json +++ b/test/fixture-cache/lib/Multiplier2_artifacts.json @@ -5,7 +5,7 @@ "circuitSourceName": "circuits/fixture/lib/BasicInLib.circom", "baseCircuitInfo": { "constraintsNumber": 1, - "protocol": "groth16", + "protocol": ["groth16", "groth16", "plonk"], "signals": [ { "name": "in1",