Skip to content

Commit

Permalink
Added circuit types Hardhat definition generation of all possible cir…
Browse files Browse the repository at this point in the history
…cuits
  • Loading branch information
KyrylR committed Jul 2, 2024
1 parent 17a698e commit 4d42041
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 15 deletions.
14 changes: 13 additions & 1 deletion src/core/BaseTSGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,19 @@ export default class BaseTSGenerator {
* @returns {string} The extracted circuit name.
*/
protected _getCircuitName(circuitArtifact: CircuitArtifact): string {
return `${circuitArtifact.circuitName.replace(path.extname(circuitArtifact.circuitName), "")}Circuit`;
return `${circuitArtifact.circuitName.replace(path.extname(circuitArtifact.circuitName), "")}`;
}

/**
* Returns the full circuit name.
*
* The full circuit name is a combination of the source name and the circuit name, separated by a colon.
*
* @param {CircuitArtifact} circuitArtifact - The circuit artifact from which the full circuit name is extracted.
* @returns {string} The full circuit name.
*/
protected _getFullCircuitName(circuitArtifact: CircuitArtifact): string {
return `${circuitArtifact.sourceName}:${this._getCircuitName(circuitArtifact)}`;
}

/**
Expand Down
23 changes: 19 additions & 4 deletions src/core/CircuitTypesGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,19 @@ export default class CircuitTypesGenerator extends ZkitTSGenerator {
public async getCircuitObject(circuitName: string): Promise<any> {
const pathToGeneratedTypes = path.join(this._projectRoot, this.getOutputTypesDir());

if (this._nameToObjectNameMap.size === 0) {
throw new Error("No circuit types have been generated.");
}

const module = await import(pathToGeneratedTypes);

return module[circuitName];
const circuitObjectPath = this._nameToObjectNameMap.get(circuitName);

if (!circuitObjectPath) {
throw new Error(`Circuit ${circuitName} type does not exist.`);
}

return circuitObjectPath.split(".").reduce((acc, key) => acc[key], module as any);
}

/**
Expand Down Expand Up @@ -101,7 +111,9 @@ export default class CircuitTypesGenerator extends ZkitTSGenerator {
const indexFilesMap: Map<string, string[]> = new Map();
const isCircuitNameExist: Map<string, boolean> = new Map();

const topLevelCircuits: string[] = [];
const topLevelCircuits: {
[circuitName: string]: ArtifactWithPath[];
} = {};

for (const typePath of typePaths) {
const levels: string[] = typePath.pathToGeneratedFile
Expand Down Expand Up @@ -136,11 +148,14 @@ export default class CircuitTypesGenerator extends ZkitTSGenerator {
...(indexFilesMap.get(pathToMainIndexFile) === undefined ? [] : indexFilesMap.get(pathToMainIndexFile)!),
this._getExportDeclarationForFile(path.relative(path.join(this._projectRoot), levels.join(path.sep))),
]);

topLevelCircuits.push(this._getCircuitName(typePath.circuitArtifact));
}

isCircuitNameExist.set(typePath.circuitArtifact.circuitName, true);

topLevelCircuits[typePath.circuitArtifact.circuitName] =
topLevelCircuits[typePath.circuitArtifact.circuitName] === undefined
? [typePath]
: [...topLevelCircuits[typePath.circuitArtifact.circuitName], typePath];
}

for (const [absolutePath, content] of indexFilesMap) {
Expand Down
58 changes: 55 additions & 3 deletions src/core/ZkitTSGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,68 @@ import prettier from "prettier";

import BaseTSGenerator from "./BaseTSGenerator";

import { CircuitArtifact, Inputs, TypeExtensionTemplateParams, WrapperTemplateParams } from "../types";
import {
ArtifactWithPath,
CircuitArtifact,
CircuitClass,
Inputs,
TypeExtensionTemplateParams,
WrapperTemplateParams,
} from "../types";

import { normalizeName } from "../utils";
import { SignalTypeNames, SignalVisibilityNames } from "../constants";

export default class ZkitTSGenerator extends BaseTSGenerator {
protected async _genHardhatZkitTypeExtension(circuitNames: string[]): Promise<string> {
protected _nameToObjectNameMap: Map<string, string> = new Map();

protected async _genHardhatZkitTypeExtension(circuits: {
[circuitName: string]: ArtifactWithPath[];
}): Promise<string> {
const template = fs.readFileSync(path.join(__dirname, "templates", "type-extension.ts.ejs"), "utf8");

const circuitClasses: CircuitClass[] = [];

const keys = Object.keys(circuits);

const outputTypesDir = this.getOutputTypesDir();

for (let i = 0; i < keys.length; i++) {
const artifacts = circuits[keys[i]];

if (artifacts.length === 1) {
circuitClasses.push({
name: this._getCircuitName(artifacts[0].circuitArtifact),
object: this._getCircuitName(artifacts[0].circuitArtifact),
});

this._nameToObjectNameMap.set(
this._getCircuitName(artifacts[0].circuitArtifact),
this._getCircuitName(artifacts[0].circuitArtifact),
);

continue;
}

for (const artifact of artifacts) {
const objectName = path
.normalize(artifact.pathToGeneratedFile.replace(outputTypesDir, ""))
.split(path.sep)
.filter((level) => level !== "")
.map((level, index, array) => (index !== array.length - 1 ? normalizeName(level) : level.replace(".ts", "")))
.join(".");

circuitClasses.push({
name: this._getFullCircuitName(artifact.circuitArtifact),
object: objectName,
});

this._nameToObjectNameMap.set(this._getFullCircuitName(artifact.circuitArtifact), objectName);
}
}

const templateParams: TypeExtensionTemplateParams = {
circuitClassNames: circuitNames,
circuitClasses,
};

return await prettier.format(ejs.render(template, templateParams), { parser: "typescript" });
Expand Down
6 changes: 3 additions & 3 deletions src/core/templates/type-extension.ts.ejs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ import * as Circuits from ".";

declare module "hardhat/types/runtime" {
interface HardhatZKit {
<% for (let i = 0; i < circuitClassNames.length; i++) { -%>
<% for (let i = 0; i < circuitClasses.length; i++) { -%>
getCircuit(
name: "<%= circuitClassNames[i] %>"
) : Promise<Circuits.<%= circuitClassNames[i] %>>;
name: "<%= circuitClasses[i].name %>"
) : Promise<Circuits.<%= circuitClasses[i].object %>>;
<% } -%>
}
}
7 changes: 6 additions & 1 deletion src/types/typesGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ export interface WrapperTemplateParams {
circuitClassName: string;
}

export interface CircuitClass {
name: string;
object: string;
}

export interface TypeExtensionTemplateParams {
circuitClassNames: string[];
circuitClasses: CircuitClass[];
}
25 changes: 22 additions & 3 deletions test/CircuitProofGeneration.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,18 @@ import { generateAST } from "./helpers/generator";

import CircuitTypesGenerator from "../src/core/CircuitTypesGenerator";

describe.only("Circuit Proof Generation", function () {
describe("Circuit Proof Generation", function () {
const astDir = "test/cache/circuits-ast";

const circuitTypesGenerator = new CircuitTypesGenerator({
basePath: "test/fixture",
circuitsASTPaths: ["test/cache/circuits-ast/Basic.json"],
circuitsASTPaths: [
"test/cache/circuits-ast/Basic.json",
"test/cache/circuits-ast/credentialAtomicQueryMTPV2OnChainVoting.json",
"test/cache/circuits-ast/lib/BasicInLib.json",
"test/cache/circuits-ast/auth/BasicInAuth.json",
"test/cache/circuits-ast/auth/EMultiplier.json",
],
});

const config: CircuitZKitConfig = {
Expand All @@ -29,10 +35,23 @@ describe.only("Circuit Proof Generation", function () {
});

it("should generate and verify proof", async () => {
const object = await circuitTypesGenerator.getCircuitObject("Multiplier2Circuit");
const object = await circuitTypesGenerator.getCircuitObject("test/fixture/Basic.circom:Multiplier2");

const circuit = new object(config);

const proof = await circuit.generateProof({ in1: 2, in2: 3 });
expect(await circuit.verifyProof(proof)).to.be.true;
});

it("should correctly import all of the zktype objects", async () => {
new (await circuitTypesGenerator.getCircuitObject("test/fixture/Basic.circom:Multiplier2"))();
new (await circuitTypesGenerator.getCircuitObject("test/fixture/auth/BasicInAuth.circom:Multiplier2"))();
new (await circuitTypesGenerator.getCircuitObject("test/fixture/lib/BasicInLib.circom:Multiplier2"))();
new (await circuitTypesGenerator.getCircuitObject("CredentialAtomicQueryMTPOnChainVoting"))();
new (await circuitTypesGenerator.getCircuitObject("EnhancedMultiplier"))();

await expect(circuitTypesGenerator.getCircuitObject("Multiplier2")).to.be.rejectedWith(
"Circuit Multiplier2 type does not exist.",
);
});
});

0 comments on commit 4d42041

Please sign in to comment.