Skip to content

Commit

Permalink
Add support for Plonk Protocol per circuit (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
KyrylR authored Oct 8, 2024
1 parent 08c4b88 commit 7a360e8
Show file tree
Hide file tree
Showing 22 changed files with 230 additions and 57 deletions.
12 changes: 6 additions & 6 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@solarity/zktype",
"version": "0.3.1",
"version": "0.4.0-rc.0",
"description": "Unleash TypeScript bindings for Circom circuits",
"main": "dist/index.js",
"types": "dist/index.d.ts",
Expand Down Expand Up @@ -49,7 +49,7 @@
"typescript": "5.5.4"
},
"peerDependencies": {
"@solarity/zkit": "^0.2.4"
"@solarity/zkit": "^0.3.0-rc.0"
},
"devDependencies": {
"@types/chai": "^4.3.12",
Expand Down
5 changes: 5 additions & 0 deletions src/constants/protocol.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
export const Groth16CalldataPointsType =
"[NumericString, NumericString], [[NumericString, NumericString], [NumericString, NumericString]], [NumericString, NumericString]";

export const PlonkCalldataPointsType =
"[NumericString, NumericString, NumericString, NumericString, NumericString, NumericString, NumericString, NumericString, NumericString, NumericString, NumericString, NumericString, NumericString, NumericString, NumericString, NumericString, NumericString, NumericString, NumericString, NumericString, NumericString, NumericString, NumericString, NumericString]";
52 changes: 35 additions & 17 deletions src/core/CircuitTypesGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import ZkitTSGenerator from "./ZkitTSGenerator";
import { normalizeName } from "../utils";

import { Formats } from "../constants";
import { CircuitArtifact, ArtifactWithPath } from "../types";
import { CircuitArtifact, ArtifactWithPath, GeneratedCircuitWrapperResult } from "../types";

/**
* `CircuitTypesGenerator` is need for generating TypeScript bindings based on circuit artifacts.
Expand All @@ -21,11 +21,15 @@ export class CircuitTypesGenerator extends ZkitTSGenerator {
/**
* Returns an object that represents the circuit class based on the circuit name.
*/
public async getCircuitObject(circuitName: string): Promise<any> {
public async getCircuitObject(circuitName: string, protocol?: string): Promise<any> {
const pathToGeneratedTypes = this.getOutputTypesDir();

const module = await import(pathToGeneratedTypes);

if (protocol) {
circuitName += this._getPrefix(protocol.toLowerCase());
}

if (!this._isFullyQualifiedCircuitName(circuitName)) {
if (!module[circuitName]) {
throw new Error(`Circuit ${circuitName} type does not exist.`);
Expand Down Expand Up @@ -88,14 +92,22 @@ export class CircuitTypesGenerator extends ZkitTSGenerator {
});

const pathToGeneratedFile = path.join(this.getOutputTypesDir(), circuitTypePath);
const preparedNode = await this._returnTSDefinitionByArtifact(circuitArtifacts[i], pathToGeneratedFile);
const preparedNodes: GeneratedCircuitWrapperResult[] = await this._returnTSDefinitionByArtifact(
circuitArtifacts[i],
pathToGeneratedFile,
);

this._saveFileContent(circuitTypePath, preparedNode);
for (const preparedNode of preparedNodes) {
circuitTypePath = path.join(path.dirname(circuitTypePath), preparedNode.className + ".ts");

typePathsToResolve.push({
circuitArtifact: circuitArtifacts[i],
pathToGeneratedFile: path.join(this.getOutputTypesDir(), circuitTypePath),
});
this._saveFileContent(circuitTypePath, preparedNode.content);

typePathsToResolve.push({
circuitArtifact: circuitArtifacts[i],
pathToGeneratedFile: path.join(this.getOutputTypesDir(), circuitTypePath),
protocol: circuitArtifacts[i].baseCircuitInfo.protocol.length > 1 ? preparedNode.prefix : undefined,
});
}
}

await this._resolveTypePaths(typePathsToResolve);
Expand All @@ -117,7 +129,7 @@ export class CircuitTypesGenerator extends ZkitTSGenerator {

// index file path => its content
const indexFilesMap: Map<string, string[]> = new Map();
const isCircuitNameExist: Map<string, boolean> = new Map();
const isCircuitNameExist: Map<string, number> = new Map();

const topLevelCircuits: {
[circuitName: string]: ArtifactWithPath[];
Expand Down Expand Up @@ -151,19 +163,25 @@ export class CircuitTypesGenerator extends ZkitTSGenerator {
}
}

if (!isCircuitNameExist.has(typePath.circuitArtifact.circuitTemplateName)) {
const circuitName = typePath.circuitArtifact.circuitTemplateName;

if (
isCircuitNameExist.get(circuitName) === undefined ||
isCircuitNameExist.get(circuitName)! < typePath.circuitArtifact.baseCircuitInfo.protocol.length
) {
indexFilesMap.set(pathToMainIndexFile, [
...(indexFilesMap.get(pathToMainIndexFile) === undefined ? [] : indexFilesMap.get(pathToMainIndexFile)!),
this._getExportDeclarationForFile(path.relative(this._projectRoot, levels.join(path.sep))),
]);
}

isCircuitNameExist.set(typePath.circuitArtifact.circuitTemplateName, true);
isCircuitNameExist.set(
circuitName,
isCircuitNameExist.get(circuitName) === undefined ? 1 : isCircuitNameExist.get(circuitName)! + 1,
);

topLevelCircuits[typePath.circuitArtifact.circuitTemplateName] =
topLevelCircuits[typePath.circuitArtifact.circuitTemplateName] === undefined
? [typePath]
: [...topLevelCircuits[typePath.circuitArtifact.circuitTemplateName], typePath];
topLevelCircuits[circuitName] =
topLevelCircuits[circuitName] === undefined ? [typePath] : [...topLevelCircuits[circuitName], typePath];
}

for (const [absolutePath, content] of indexFilesMap) {
Expand Down Expand Up @@ -253,10 +271,10 @@ export class CircuitTypesGenerator extends ZkitTSGenerator {
private async _returnTSDefinitionByArtifact(
circuitArtifact: CircuitArtifact,
pathToGeneratedFile: string,
): Promise<string> {
): Promise<GeneratedCircuitWrapperResult[]> {
switch (circuitArtifact._format) {
case Formats.V1HH_ZKIT_TYPE:
return await this._genCircuitWrapperClassContent(circuitArtifact, pathToGeneratedFile);
return await this._genCircuitWrappersClassContent(circuitArtifact, pathToGeneratedFile);
default:
throw new Error(`Unsupported format: ${circuitArtifact._format}`);
}
Expand Down
131 changes: 118 additions & 13 deletions src/core/ZkitTSGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@ import {
DefaultWrapperTemplateParams,
WrapperTemplateParams,
SignalInfo,
GeneratedCircuitWrapperResult,
} from "../types";

import { normalizeName } from "../utils";
import { SignalTypeNames, SignalVisibilityNames } from "../constants";
import { Groth16CalldataPointsType, PlonkCalldataPointsType } from "../constants/protocol";

export default class ZkitTSGenerator extends BaseTSGenerator {
protected async _genHardhatZkitTypeExtension(circuits: {
Expand All @@ -37,6 +39,23 @@ export default class ZkitTSGenerator extends BaseTSGenerator {
circuitClasses.push({
name: this._getCircuitName(artifacts[0].circuitArtifact),
object: this._getCircuitName(artifacts[0].circuitArtifact),
protocol: artifacts[0].protocol,
});

continue;
}

if (artifacts.length === 2 && artifacts[0].protocol !== artifacts[1].protocol) {
circuitClasses.push({
name: this._getCircuitName(artifacts[0].circuitArtifact),
object: this._getCircuitName(artifacts[0].circuitArtifact) + this._getPrefix(artifacts[0].protocol!),
protocol: artifacts[0].protocol,
});

circuitClasses.push({
name: this._getCircuitName(artifacts[1].circuitArtifact),
object: this._getCircuitName(artifacts[1].circuitArtifact) + this._getPrefix(artifacts[1].protocol!),
protocol: artifacts[1].protocol,
});

continue;
Expand All @@ -46,6 +65,7 @@ export default class ZkitTSGenerator extends BaseTSGenerator {
circuitClasses.push({
name: this._getFullCircuitName(artifact.circuitArtifact),
object: this._getObjectPath(artifact.pathToGeneratedFile),
protocol: artifact.protocol,
});
}
}
Expand All @@ -66,10 +86,45 @@ export default class ZkitTSGenerator extends BaseTSGenerator {
.join(".");
}

protected async _genCircuitWrapperClassContent(
protected async _genCircuitWrappersClassContent(
circuitArtifact: CircuitArtifact,
pathToGeneratedFile: string,
): Promise<GeneratedCircuitWrapperResult[]> {
this._validateCircuitArtifact(circuitArtifact);

const result: GeneratedCircuitWrapperResult[] = [];

const unifiedProtocolType = new Set(circuitArtifact.baseCircuitInfo.protocol);
for (const protocolType of unifiedProtocolType) {
const content = await this._genSingleCircuitWrapperClassContent(
circuitArtifact,
pathToGeneratedFile,
protocolType,
unifiedProtocolType.size > 1,
);

result.push(content);
}

return result;
}

protected async _genDefaultCircuitWrapperClassContent(circuitArtifact: CircuitArtifact): Promise<string> {
const template = fs.readFileSync(path.join(__dirname, "templates", "default-circuit-wrapper.ts.ejs"), "utf8");

const templateParams: DefaultWrapperTemplateParams = {
circuitClassName: this._getCircuitName(circuitArtifact),
};

return await prettier.format(ejs.render(template, templateParams), { parser: "typescript" });
}

private async _genSingleCircuitWrapperClassContent(
circuitArtifact: CircuitArtifact,
pathToGeneratedFile: string,
): Promise<string> {
protocolType: "groth16" | "plonk",
isPrefixed: boolean = false,
): Promise<GeneratedCircuitWrapperResult> {
const template = fs.readFileSync(path.join(__dirname, "templates", "circuit-wrapper.ts.ejs"), "utf8");

let outputCounter: number = 0;
Expand Down Expand Up @@ -113,28 +168,28 @@ export default class ZkitTSGenerator extends BaseTSGenerator {
}

const pathToUtils = path.join(this.getOutputTypesDir(), "utils");
const circuitClassName = this._getCircuitName(circuitArtifact) + (isPrefixed ? this._getPrefix(protocolType) : "");

const templateParams: WrapperTemplateParams = {
circuitClassName: this._getCircuitName(circuitArtifact),
protocolTypeName: protocolType,
protocolImplementerName: this._getProtocolImplementerName(protocolType),
proofTypeInternalName: this._getProofTypeInternalName(protocolType),
circuitClassName,
publicInputsTypeName: this._getTypeName(circuitArtifact, "Public"),
calldataPubSignalsType: this._getCalldataPubSignalsType(calldataPubSignalsCount),
publicInputs,
privateInputs,
calldataPointsType: this._getCalldataPointsType(protocolType),
proofTypeName: this._getTypeName(circuitArtifact, "Proof"),
privateInputsTypeName: this._getTypeName(circuitArtifact, "Private"),
pathToUtils: path.relative(path.dirname(pathToGeneratedFile), pathToUtils),
};

return await prettier.format(ejs.render(template, templateParams), { parser: "typescript" });
}

protected async _genDefaultCircuitWrapperClassContent(circuitArtifact: CircuitArtifact): Promise<string> {
const template = fs.readFileSync(path.join(__dirname, "templates", "default-circuit-wrapper.ts.ejs"), "utf8");

const templateParams: DefaultWrapperTemplateParams = {
circuitClassName: this._getCircuitName(circuitArtifact),
return {
content: await prettier.format(ejs.render(template, templateParams), { parser: "typescript" }),
className: circuitClassName,
prefix: this._getPrefix(protocolType).toLowerCase(),
};

return await prettier.format(ejs.render(template, templateParams), { parser: "typescript" });
}

private _getCalldataPubSignalsType(pubSignalsCount: number): string {
Expand All @@ -150,4 +205,54 @@ export default class ZkitTSGenerator extends BaseTSGenerator {

return signal.dimension.reduce((acc: number, dim: string) => acc * Number(dim), 1);
}

private _getProtocolImplementerName(protocolType: string): any {
switch (protocolType) {
case "groth16":
return "Groth16Implementer";
case "plonk":
return "PlonkImplementer";
default:
throw new Error(`Unknown protocol: ${protocolType}`);
}
}

private _getProofTypeInternalName(protocolType: string): any {
switch (protocolType) {
case "groth16":
return "Groth16Proof";
case "plonk":
return "PlonkProof";
default:
throw new Error(`Unknown protocol: ${protocolType}`);
}
}

private _getCalldataPointsType(protocolType: string): any {
switch (protocolType) {
case "groth16":
return Groth16CalldataPointsType;
case "plonk":
return PlonkCalldataPointsType;
default:
throw new Error(`Unknown protocol: ${protocolType}`);
}
}

protected _getPrefix(protocolType: string): string {
switch (protocolType) {
case "groth16":
return "Groth16";
case "plonk":
return "Plonk";
default:
throw new Error(`Unknown protocol: ${protocolType}`);
}
}

private _validateCircuitArtifact(circuitArtifact: CircuitArtifact): void {
if (!circuitArtifact.baseCircuitInfo.protocol) {
throw new Error(`ZKType: Protocol is missing in the circuit artifact: ${circuitArtifact.circuitTemplateName}`);
}
}
}
13 changes: 7 additions & 6 deletions src/core/templates/circuit-wrapper.ts.ejs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ import {
CircuitZKit,
CircuitZKitConfig,
Groth16Proof,
PlonkProof,
NumberLike,
NumericString,
PublicSignals,
Groth16Implementer,
PlonkImplementer,
} from "@solarity/zkit";

import { normalizePublicSignals, denormalizePublicSignals } from "<%= pathToUtils %>";
Expand All @@ -22,20 +25,18 @@ export type <%= publicInputsTypeName %> = {
}

export type <%= proofTypeName %> = {
proof: Groth16Proof;
proof: <%= proofTypeInternalName %>;
publicSignals: <%= publicInputsTypeName %>;
}

export type Calldata = [
[NumericString, NumericString],
[[NumericString, NumericString], [NumericString, NumericString]],
[NumericString, NumericString],
<%= calldataPointsType %>,
<%= calldataPubSignalsType %>,
];

export class <%= circuitClassName %> extends CircuitZKit {
export class <%= circuitClassName %> extends CircuitZKit<"<%= protocolTypeName %>"> {
constructor(config: CircuitZKitConfig) {
super(config);
super(config, new <%= protocolImplementerName %>());
}

public async generateProof(inputs: <%= privateInputsTypeName %>): Promise<<%= proofTypeName %>> {
Expand Down
Loading

0 comments on commit 7a360e8

Please sign in to comment.