Skip to content

Commit

Permalink
Added support for multiple protocols
Browse files Browse the repository at this point in the history
  • Loading branch information
KyrylR committed Oct 4, 2024
1 parent 902925f commit 0a0f4ed
Show file tree
Hide file tree
Showing 12 changed files with 120 additions and 55 deletions.
30 changes: 21 additions & 9 deletions src/core/CircuitTypesGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import ZkitTSGenerator from "./ZkitTSGenerator";
import { normalizeName } from "../utils";

import { Formats } from "../constants";
import { CircuitArtifact, ArtifactWithPath } from "../types";
import { CircuitArtifact, ArtifactWithPath, GeneratedCircuitWrapperResult } from "../types";

/**
* `CircuitTypesGenerator` is need for generating TypeScript bindings based on circuit artifacts.
Expand Down Expand Up @@ -88,14 +88,26 @@ export class CircuitTypesGenerator extends ZkitTSGenerator {
});

const pathToGeneratedFile = path.join(this.getOutputTypesDir(), circuitTypePath);
const preparedNode = await this._returnTSDefinitionByArtifact(circuitArtifacts[i], pathToGeneratedFile);
const preparedNodes: GeneratedCircuitWrapperResult[] = await this._returnTSDefinitionByArtifact(
circuitArtifacts[i],
pathToGeneratedFile,
);

this._saveFileContent(circuitTypePath, preparedNode);
for (const preparedNode of preparedNodes) {
circuitTypePath = path.join(path.dirname(circuitTypePath), preparedNode.className + ".ts");

typePathsToResolve.push({
circuitArtifact: circuitArtifacts[i],
pathToGeneratedFile: path.join(this.getOutputTypesDir(), circuitTypePath),
});
this._saveFileContent(circuitTypePath, preparedNode.content);

typePathsToResolve.push({
circuitArtifact: {
...circuitArtifacts[i],
circuitTemplateName:
circuitArtifacts[i].circuitTemplateName +
(circuitArtifacts[i].baseCircuitInfo.protocol.length > 1 ? preparedNode.prefix : ""),
},
pathToGeneratedFile: path.join(this.getOutputTypesDir(), circuitTypePath),
});
}
}

await this._resolveTypePaths(typePathsToResolve);
Expand Down Expand Up @@ -253,10 +265,10 @@ export class CircuitTypesGenerator extends ZkitTSGenerator {
private async _returnTSDefinitionByArtifact(
circuitArtifact: CircuitArtifact,
pathToGeneratedFile: string,
): Promise<string> {
): Promise<GeneratedCircuitWrapperResult[]> {
switch (circuitArtifact._format) {
case Formats.V1HH_ZKIT_TYPE:
return await this._genCircuitWrapperClassContent(circuitArtifact, pathToGeneratedFile);
return await this._genCircuitWrappersClassContent(circuitArtifact, pathToGeneratedFile);
default:
throw new Error(`Unsupported format: ${circuitArtifact._format}`);
}
Expand Down
93 changes: 67 additions & 26 deletions src/core/ZkitTSGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {
DefaultWrapperTemplateParams,
WrapperTemplateParams,
SignalInfo,
GeneratedCircuitWrapperResult,
} from "../types";

import { normalizeName } from "../utils";
Expand Down Expand Up @@ -67,12 +68,45 @@ export default class ZkitTSGenerator extends BaseTSGenerator {
.join(".");
}

protected async _genCircuitWrapperClassContent(
protected async _genCircuitWrappersClassContent(
circuitArtifact: CircuitArtifact,
pathToGeneratedFile: string,
): Promise<string> {
): Promise<GeneratedCircuitWrapperResult[]> {
this._validateCircuitArtifact(circuitArtifact);

const result: GeneratedCircuitWrapperResult[] = [];

const unifiedProtocolType = new Set(circuitArtifact.baseCircuitInfo.protocol);
for (const protocolType of unifiedProtocolType) {
const content = await this._genSingleCircuitWrapperClassContent(
circuitArtifact,
pathToGeneratedFile,
protocolType,
unifiedProtocolType.size > 1,
);

result.push(content);
}

return result;
}

protected async _genDefaultCircuitWrapperClassContent(circuitArtifact: CircuitArtifact): Promise<string> {
const template = fs.readFileSync(path.join(__dirname, "templates", "default-circuit-wrapper.ts.ejs"), "utf8");

const templateParams: DefaultWrapperTemplateParams = {
circuitClassName: this._getCircuitName(circuitArtifact),
};

return await prettier.format(ejs.render(template, templateParams), { parser: "typescript" });
}

private async _genSingleCircuitWrapperClassContent(
circuitArtifact: CircuitArtifact,
pathToGeneratedFile: string,
protocolType: "groth16" | "plonk",
isPrefixed: boolean = false,
): Promise<GeneratedCircuitWrapperResult> {
const template = fs.readFileSync(path.join(__dirname, "templates", "circuit-wrapper.ts.ejs"), "utf8");

let outputCounter: number = 0;
Expand Down Expand Up @@ -116,32 +150,28 @@ export default class ZkitTSGenerator extends BaseTSGenerator {
}

const pathToUtils = path.join(this.getOutputTypesDir(), "utils");
const circuitClassName = this._getCircuitName(circuitArtifact) + (isPrefixed ? this._getPrefix(protocolType) : "");

const templateParams: WrapperTemplateParams = {
protocolTypeName: circuitArtifact.baseCircuitInfo.protocol,
protocolImplementerName: this._getProtocolImplementerName(circuitArtifact),
proofTypeInternalName: this._getProofTypeInternalName(circuitArtifact),
circuitClassName: this._getCircuitName(circuitArtifact),
protocolTypeName: protocolType,
protocolImplementerName: this._getProtocolImplementerName(protocolType),
proofTypeInternalName: this._getProofTypeInternalName(protocolType),
circuitClassName,
publicInputsTypeName: this._getTypeName(circuitArtifact, "Public"),
calldataPubSignalsType: this._getCalldataPubSignalsType(calldataPubSignalsCount),
publicInputs,
privateInputs,
calldataPointsType: this._getCalldataPointsType(circuitArtifact),
calldataPointsType: this._getCalldataPointsType(protocolType),
proofTypeName: this._getTypeName(circuitArtifact, "Proof"),
privateInputsTypeName: this._getTypeName(circuitArtifact, "Private"),
pathToUtils: path.relative(path.dirname(pathToGeneratedFile), pathToUtils),
};

return await prettier.format(ejs.render(template, templateParams), { parser: "typescript" });
}

protected async _genDefaultCircuitWrapperClassContent(circuitArtifact: CircuitArtifact): Promise<string> {
const template = fs.readFileSync(path.join(__dirname, "templates", "default-circuit-wrapper.ts.ejs"), "utf8");

const templateParams: DefaultWrapperTemplateParams = {
circuitClassName: this._getCircuitName(circuitArtifact),
return {
content: await prettier.format(ejs.render(template, templateParams), { parser: "typescript" }),
className: circuitClassName,
prefix: this._getPrefix(protocolType),
};

return await prettier.format(ejs.render(template, templateParams), { parser: "typescript" });
}

private _getCalldataPubSignalsType(pubSignalsCount: number): string {
Expand All @@ -158,36 +188,47 @@ export default class ZkitTSGenerator extends BaseTSGenerator {
return signal.dimension.reduce((acc: number, dim: string) => acc * Number(dim), 1);
}

private _getProtocolImplementerName(circuitArtifact: CircuitArtifact): any {
switch (circuitArtifact.baseCircuitInfo.protocol) {
private _getProtocolImplementerName(protocolType: string): any {
switch (protocolType) {
case "groth16":
return "Groth16Implementer";
case "plonk":
return "PlonkImplementer";
default:
throw new Error(`Unknown protocol: ${circuitArtifact.baseCircuitInfo.protocol}`);
throw new Error(`Unknown protocol: ${protocolType}`);
}
}

private _getProofTypeInternalName(circuitArtifact: CircuitArtifact): any {
switch (circuitArtifact.baseCircuitInfo.protocol) {
private _getProofTypeInternalName(protocolType: string): any {
switch (protocolType) {
case "groth16":
return "Groth16Proof";
case "plonk":
return "PlonkProof";
default:
throw new Error(`Unknown protocol: ${circuitArtifact.baseCircuitInfo.protocol}`);
throw new Error(`Unknown protocol: ${protocolType}`);
}
}

private _getCalldataPointsType(circuitArtifact: CircuitArtifact): any {
switch (circuitArtifact.baseCircuitInfo.protocol) {
private _getCalldataPointsType(protocolType: string): any {
switch (protocolType) {
case "groth16":
return Groth16CalldataPointsType;
case "plonk":
return PlonkCalldataPointsType;
default:
throw new Error(`Unknown protocol: ${circuitArtifact.baseCircuitInfo.protocol}`);
throw new Error(`Unknown protocol: ${protocolType}`);
}
}

private _getPrefix(protocolType: string): string {
switch (protocolType) {
case "groth16":
return "Groth16";
case "plonk":
return "Plonk";
default:
throw new Error(`Unknown protocol: ${protocolType}`);
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/types/circuitArtifact.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ export type CircuitArtifact = {
* @param {SignalInfo[]} signals - The array of `input` and `output` signals used in the circuit.
*/
export type BaseCircuitInfo = {
protocol: "groth16" | "plonk";
protocol: ["groth16" | "plonk"];
constraintsNumber: number;
signals: SignalInfo[];
};
Expand Down
6 changes: 6 additions & 0 deletions src/types/typesGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,9 @@ export interface CircuitClass {
export interface TypeExtensionTemplateParams {
circuitClasses: CircuitClass[];
}

export interface GeneratedCircuitWrapperResult {
content: string;
className: string;
prefix: string;
}
20 changes: 11 additions & 9 deletions test/CircuitProofGeneration.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ describe("Circuit Proof Generation", function () {
});

it("should generate and verify proof for Basic.circom", async () => {
const object = await circuitTypesGenerator.getCircuitObject("circuits/fixture/Basic.circom:Multiplier2");
const object = await circuitTypesGenerator.getCircuitObject("circuits/fixture/Basic.circom:Multiplier2Groth16");

const circuit = new object(basicConfig);

Expand Down Expand Up @@ -73,16 +73,18 @@ describe("Circuit Proof Generation", function () {
});

it("should correctly import all of the zktype objects", async () => {
new (await circuitTypesGenerator.getCircuitObject("circuits/fixture/Basic.circom:Multiplier2"))();
new (await circuitTypesGenerator.getCircuitObject("circuits/fixture/lib/BasicInLib.circom:Multiplier2"))();
new (await circuitTypesGenerator.getCircuitObject("circuits/fixture/auth/BasicInAuth.circom:Multiplier2"))();
new (await circuitTypesGenerator.getCircuitObject("CredentialAtomicQueryMTPOnChainVoting"))();
new (await circuitTypesGenerator.getCircuitObject("EnhancedMultiplier"))();
new (await circuitTypesGenerator.getCircuitObject("Matrix"))();

await expect(circuitTypesGenerator.getCircuitObject("Multiplier3")).to.be.rejectedWith(
"Circuit Multiplier3 type does not exist.",
);
new (await circuitTypesGenerator.getCircuitObject("Multiplier2"))();
new (await circuitTypesGenerator.getCircuitObject("circuits/fixture/lib/BasicInLib.circom:Multiplier2Groth16"))();
new (await circuitTypesGenerator.getCircuitObject("circuits/fixture/Basic.circom:Multiplier2Groth16"))();
new (await circuitTypesGenerator.getCircuitObject("circuits/fixture/lib/BasicInLib.circom:Multiplier2Plonk"))();
new (await circuitTypesGenerator.getCircuitObject("CredentialAtomicQueryMTPOnChainVotingGroth16"))();
new (await circuitTypesGenerator.getCircuitObject("circuits/fixture/Basic.circom:Multiplier2Plonk"))();

await expect(
circuitTypesGenerator.getCircuitObject("circuits/fixture/lib/Basic.circom:Multiplier2Groth16"),
).to.be.rejectedWith("Circuit Multiplier2Groth16 type does not exist.");
await expect(circuitTypesGenerator.getCircuitObject("test/fixture/Basic.circom:Multiplier3")).to.be.rejectedWith(
"Circuit Multiplier3 type does not exist.",
);
Expand Down
12 changes: 8 additions & 4 deletions test/CircuitTypesGenerator.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,17 @@ import { CircuitTypesGenerator } from "../src";

describe("Circuit Types Generation", function () {
const expectedTypes = [
"core/lib/BasicInLib.circom/Multiplier2.ts",
"core/lib/BasicInLib.circom/Multiplier2Groth16.ts",
"core/lib/BasicInLib.circom/Multiplier2Plonk.ts",
"core/auth/EnhancedMultiplier.ts",
"core/auth/Matrix.ts",
"core/auth/Multiplier2.ts",
"core/Basic.circom/Multiplier2.ts",
"core/lib/BasicInLib.circom/Multiplier2.ts",
"core/CredentialAtomicQueryMTPOnChainVoting.ts",
"core/Basic.circom/Multiplier2Groth16.ts",
"core/Basic.circom/Multiplier2Plonk.ts",
"core/lib/BasicInLib.circom/Multiplier2Groth16.ts",
"core/lib/BasicInLib.circom/Multiplier2Plonk.ts",
"core/CredentialAtomicQueryMTPOnChainVotingPlonk.ts",
"core/CredentialAtomicQueryMTPOnChainVotingGroth16.ts",
];

const circuitTypesGenerator = new CircuitTypesGenerator({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"circuitSourceName": "circuits/fixture/credentialAtomicQueryMTPV2OnChainVoting.circom",
"baseCircuitInfo": {
"constraintsNumber": 86791,
"protocol": "groth16",
"protocol": ["groth16", "plonk"],
"signals": [
{
"name": "merklized",
Expand Down
2 changes: 1 addition & 1 deletion test/fixture-cache/Multiplier2_artifacts.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"circuitSourceName": "circuits/fixture/Basic.circom",
"baseCircuitInfo": {
"constraintsNumber": 1,
"protocol": "groth16",
"protocol": ["groth16", "plonk"],
"signals": [
{
"name": "in1",
Expand Down
2 changes: 1 addition & 1 deletion test/fixture-cache/auth/EnhancedMultiplier_artifacts.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"circuitSourceName": "circuits/fixture/auth/EMultiplier.circom",
"baseCircuitInfo": {
"constraintsNumber": 1,
"protocol": "groth16",
"protocol": ["groth16"],
"signals": [
{
"name": "in1",
Expand Down
2 changes: 1 addition & 1 deletion test/fixture-cache/auth/Matrix_artifacts.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"circuitSourceName": "circuits/fixture/auth/Matrix.circom",
"baseCircuitInfo": {
"constraintsNumber": 8,
"protocol": "groth16",
"protocol": ["groth16"],
"signals": [
{
"name": "a",
Expand Down
2 changes: 1 addition & 1 deletion test/fixture-cache/auth/Multiplier2_artifacts.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"circuitSourceName": "circuits/fixture/auth/BasicInAuth.circom",
"baseCircuitInfo": {
"constraintsNumber": 1,
"protocol": "groth16",
"protocol": ["plonk"],
"signals": [
{
"name": "in1",
Expand Down
2 changes: 1 addition & 1 deletion test/fixture-cache/lib/Multiplier2_artifacts.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"circuitSourceName": "circuits/fixture/lib/BasicInLib.circom",
"baseCircuitInfo": {
"constraintsNumber": 1,
"protocol": "groth16",
"protocol": ["groth16", "groth16", "plonk"],
"signals": [
{
"name": "in1",
Expand Down

0 comments on commit 0a0f4ed

Please sign in to comment.