From 4d420415880380e87b5b433de7f22da43492a000 Mon Sep 17 00:00:00 2001 From: Kyryl Riabov Date: Tue, 2 Jul 2024 13:00:04 +0300 Subject: [PATCH] Added circuit types Hardhat definition generation of all possible circuits --- src/core/BaseTSGenerator.ts | 14 +++++- src/core/CircuitTypesGenerator.ts | 23 ++++++++-- src/core/ZkitTSGenerator.ts | 58 ++++++++++++++++++++++-- src/core/templates/type-extension.ts.ejs | 6 +-- src/types/typesGenerator.ts | 7 ++- test/CircuitProofGeneration.test.ts | 25 ++++++++-- 6 files changed, 118 insertions(+), 15 deletions(-) diff --git a/src/core/BaseTSGenerator.ts b/src/core/BaseTSGenerator.ts index 6484f98..2a2b87e 100644 --- a/src/core/BaseTSGenerator.ts +++ b/src/core/BaseTSGenerator.ts @@ -100,7 +100,19 @@ export default class BaseTSGenerator { * @returns {string} The extracted circuit name. */ protected _getCircuitName(circuitArtifact: CircuitArtifact): string { - return `${circuitArtifact.circuitName.replace(path.extname(circuitArtifact.circuitName), "")}Circuit`; + return `${circuitArtifact.circuitName.replace(path.extname(circuitArtifact.circuitName), "")}`; + } + + /** + * Returns the full circuit name. + * + * The full circuit name is a combination of the source name and the circuit name, separated by a colon. + * + * @param {CircuitArtifact} circuitArtifact - The circuit artifact from which the full circuit name is extracted. + * @returns {string} The full circuit name. + */ + protected _getFullCircuitName(circuitArtifact: CircuitArtifact): string { + return `${circuitArtifact.sourceName}:${this._getCircuitName(circuitArtifact)}`; } /** diff --git a/src/core/CircuitTypesGenerator.ts b/src/core/CircuitTypesGenerator.ts index 3972a39..eb7c888 100644 --- a/src/core/CircuitTypesGenerator.ts +++ b/src/core/CircuitTypesGenerator.ts @@ -24,9 +24,19 @@ export default class CircuitTypesGenerator extends ZkitTSGenerator { public async getCircuitObject(circuitName: string): Promise { const pathToGeneratedTypes = path.join(this._projectRoot, this.getOutputTypesDir()); + if (this._nameToObjectNameMap.size === 0) { + throw new Error("No circuit types have been generated."); + } + const module = await import(pathToGeneratedTypes); - return module[circuitName]; + const circuitObjectPath = this._nameToObjectNameMap.get(circuitName); + + if (!circuitObjectPath) { + throw new Error(`Circuit ${circuitName} type does not exist.`); + } + + return circuitObjectPath.split(".").reduce((acc, key) => acc[key], module as any); } /** @@ -101,7 +111,9 @@ export default class CircuitTypesGenerator extends ZkitTSGenerator { const indexFilesMap: Map = new Map(); const isCircuitNameExist: Map = new Map(); - const topLevelCircuits: string[] = []; + const topLevelCircuits: { + [circuitName: string]: ArtifactWithPath[]; + } = {}; for (const typePath of typePaths) { const levels: string[] = typePath.pathToGeneratedFile @@ -136,11 +148,14 @@ export default class CircuitTypesGenerator extends ZkitTSGenerator { ...(indexFilesMap.get(pathToMainIndexFile) === undefined ? [] : indexFilesMap.get(pathToMainIndexFile)!), this._getExportDeclarationForFile(path.relative(path.join(this._projectRoot), levels.join(path.sep))), ]); - - topLevelCircuits.push(this._getCircuitName(typePath.circuitArtifact)); } isCircuitNameExist.set(typePath.circuitArtifact.circuitName, true); + + topLevelCircuits[typePath.circuitArtifact.circuitName] = + topLevelCircuits[typePath.circuitArtifact.circuitName] === undefined + ? [typePath] + : [...topLevelCircuits[typePath.circuitArtifact.circuitName], typePath]; } for (const [absolutePath, content] of indexFilesMap) { diff --git a/src/core/ZkitTSGenerator.ts b/src/core/ZkitTSGenerator.ts index d7baf2a..2374183 100644 --- a/src/core/ZkitTSGenerator.ts +++ b/src/core/ZkitTSGenerator.ts @@ -5,16 +5,68 @@ import prettier from "prettier"; import BaseTSGenerator from "./BaseTSGenerator"; -import { CircuitArtifact, Inputs, TypeExtensionTemplateParams, WrapperTemplateParams } from "../types"; +import { + ArtifactWithPath, + CircuitArtifact, + CircuitClass, + Inputs, + TypeExtensionTemplateParams, + WrapperTemplateParams, +} from "../types"; +import { normalizeName } from "../utils"; import { SignalTypeNames, SignalVisibilityNames } from "../constants"; export default class ZkitTSGenerator extends BaseTSGenerator { - protected async _genHardhatZkitTypeExtension(circuitNames: string[]): Promise { + protected _nameToObjectNameMap: Map = new Map(); + + protected async _genHardhatZkitTypeExtension(circuits: { + [circuitName: string]: ArtifactWithPath[]; + }): Promise { const template = fs.readFileSync(path.join(__dirname, "templates", "type-extension.ts.ejs"), "utf8"); + const circuitClasses: CircuitClass[] = []; + + const keys = Object.keys(circuits); + + const outputTypesDir = this.getOutputTypesDir(); + + for (let i = 0; i < keys.length; i++) { + const artifacts = circuits[keys[i]]; + + if (artifacts.length === 1) { + circuitClasses.push({ + name: this._getCircuitName(artifacts[0].circuitArtifact), + object: this._getCircuitName(artifacts[0].circuitArtifact), + }); + + this._nameToObjectNameMap.set( + this._getCircuitName(artifacts[0].circuitArtifact), + this._getCircuitName(artifacts[0].circuitArtifact), + ); + + continue; + } + + for (const artifact of artifacts) { + const objectName = path + .normalize(artifact.pathToGeneratedFile.replace(outputTypesDir, "")) + .split(path.sep) + .filter((level) => level !== "") + .map((level, index, array) => (index !== array.length - 1 ? normalizeName(level) : level.replace(".ts", ""))) + .join("."); + + circuitClasses.push({ + name: this._getFullCircuitName(artifact.circuitArtifact), + object: objectName, + }); + + this._nameToObjectNameMap.set(this._getFullCircuitName(artifact.circuitArtifact), objectName); + } + } + const templateParams: TypeExtensionTemplateParams = { - circuitClassNames: circuitNames, + circuitClasses, }; return await prettier.format(ejs.render(template, templateParams), { parser: "typescript" }); diff --git a/src/core/templates/type-extension.ts.ejs b/src/core/templates/type-extension.ts.ejs index 64cedde..2f46ab0 100644 --- a/src/core/templates/type-extension.ts.ejs +++ b/src/core/templates/type-extension.ts.ejs @@ -2,10 +2,10 @@ import * as Circuits from "."; declare module "hardhat/types/runtime" { interface HardhatZKit { - <% for (let i = 0; i < circuitClassNames.length; i++) { -%> + <% for (let i = 0; i < circuitClasses.length; i++) { -%> getCircuit( - name: "<%= circuitClassNames[i] %>" - ) : Promise>; + name: "<%= circuitClasses[i].name %>" + ) : Promise>; <% } -%> } } diff --git a/src/types/typesGenerator.ts b/src/types/typesGenerator.ts index 9bd6000..6b7f75f 100644 --- a/src/types/typesGenerator.ts +++ b/src/types/typesGenerator.ts @@ -19,6 +19,11 @@ export interface WrapperTemplateParams { circuitClassName: string; } +export interface CircuitClass { + name: string; + object: string; +} + export interface TypeExtensionTemplateParams { - circuitClassNames: string[]; + circuitClasses: CircuitClass[]; } diff --git a/test/CircuitProofGeneration.test.ts b/test/CircuitProofGeneration.test.ts index 2334fc6..5b08df9 100644 --- a/test/CircuitProofGeneration.test.ts +++ b/test/CircuitProofGeneration.test.ts @@ -8,12 +8,18 @@ import { generateAST } from "./helpers/generator"; import CircuitTypesGenerator from "../src/core/CircuitTypesGenerator"; -describe.only("Circuit Proof Generation", function () { +describe("Circuit Proof Generation", function () { const astDir = "test/cache/circuits-ast"; const circuitTypesGenerator = new CircuitTypesGenerator({ basePath: "test/fixture", - circuitsASTPaths: ["test/cache/circuits-ast/Basic.json"], + circuitsASTPaths: [ + "test/cache/circuits-ast/Basic.json", + "test/cache/circuits-ast/credentialAtomicQueryMTPV2OnChainVoting.json", + "test/cache/circuits-ast/lib/BasicInLib.json", + "test/cache/circuits-ast/auth/BasicInAuth.json", + "test/cache/circuits-ast/auth/EMultiplier.json", + ], }); const config: CircuitZKitConfig = { @@ -29,10 +35,23 @@ describe.only("Circuit Proof Generation", function () { }); it("should generate and verify proof", async () => { - const object = await circuitTypesGenerator.getCircuitObject("Multiplier2Circuit"); + const object = await circuitTypesGenerator.getCircuitObject("test/fixture/Basic.circom:Multiplier2"); + const circuit = new object(config); const proof = await circuit.generateProof({ in1: 2, in2: 3 }); expect(await circuit.verifyProof(proof)).to.be.true; }); + + it("should correctly import all of the zktype objects", async () => { + new (await circuitTypesGenerator.getCircuitObject("test/fixture/Basic.circom:Multiplier2"))(); + new (await circuitTypesGenerator.getCircuitObject("test/fixture/auth/BasicInAuth.circom:Multiplier2"))(); + new (await circuitTypesGenerator.getCircuitObject("test/fixture/lib/BasicInLib.circom:Multiplier2"))(); + new (await circuitTypesGenerator.getCircuitObject("CredentialAtomicQueryMTPOnChainVoting"))(); + new (await circuitTypesGenerator.getCircuitObject("EnhancedMultiplier"))(); + + await expect(circuitTypesGenerator.getCircuitObject("Multiplier2")).to.be.rejectedWith( + "Circuit Multiplier2 type does not exist.", + ); + }); });