diff --git a/package-lock.json b/package-lock.json index 5cbf403..1ca29f4 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@solarity/zktype", - "version": "0.4.0-rc.0", + "version": "0.4.0-rc.1", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@solarity/zktype", - "version": "0.4.0-rc.0", + "version": "0.4.0-rc.1", "license": "MIT", "dependencies": { "ejs": "3.1.10", diff --git a/package.json b/package.json index 3bb3153..5d26817 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@solarity/zktype", - "version": "0.4.0-rc.0", + "version": "0.4.0-rc.1", "description": "Unleash TypeScript bindings for Circom circuits", "main": "dist/index.js", "types": "dist/index.d.ts", diff --git a/src/core/BaseTSGenerator.ts b/src/core/BaseTSGenerator.ts index e6051b9..26304be 100644 --- a/src/core/BaseTSGenerator.ts +++ b/src/core/BaseTSGenerator.ts @@ -70,11 +70,12 @@ export default class BaseTSGenerator { * Extracts the type name from the circuit artifact. * * @param {CircuitArtifact} circuitArtifact - The circuit artifact from which the type name is extracted. + * @param protocolType - The protocol type to be added to the type name. * @param {string} [prefix=""] - The prefix to be added to the type name. * @returns {string} The extracted type name. */ - protected _getTypeName(circuitArtifact: CircuitArtifact, prefix: string = ""): string { - return `${prefix}${circuitArtifact.circuitTemplateName.replace(path.extname(circuitArtifact.circuitTemplateName), "")}`; + protected _getTypeName(circuitArtifact: CircuitArtifact, protocolType: string, prefix: string = ""): string { + return `${prefix}${circuitArtifact.circuitTemplateName.replace(path.extname(circuitArtifact.circuitTemplateName), "")}${protocolType}`; } /** diff --git a/src/core/CircuitTypesGenerator.ts b/src/core/CircuitTypesGenerator.ts index 0857bc4..bceb80e 100644 --- a/src/core/CircuitTypesGenerator.ts +++ b/src/core/CircuitTypesGenerator.ts @@ -7,7 +7,7 @@ import ZkitTSGenerator from "./ZkitTSGenerator"; import { normalizeName } from "../utils"; import { Formats } from "../constants"; -import { CircuitArtifact, ArtifactWithPath, GeneratedCircuitWrapperResult } from "../types"; +import { CircuitArtifact, GeneratedCircuitWrapperResult, CircuitSet } from "../types"; /** * `CircuitTypesGenerator` is need for generating TypeScript bindings based on circuit artifacts. @@ -65,7 +65,7 @@ export class CircuitTypesGenerator extends ZkitTSGenerator { fs.mkdirSync(this.getOutputTypesDir(), { recursive: true }); const isNameExist: Map = new Map(); - const typePathsToResolve: ArtifactWithPath[] = []; + const circuitSet: CircuitSet = {}; for (let i = 0; i < circuitArtifacts.length; i++) { const circuitName = circuitArtifacts[i].circuitTemplateName; @@ -102,15 +102,21 @@ export class CircuitTypesGenerator extends ZkitTSGenerator { this._saveFileContent(circuitTypePath, preparedNode.content); - typePathsToResolve.push({ + if (!circuitSet[circuitName]) { + circuitSet[circuitName] = []; + } + + circuitSet[circuitName].push({ circuitArtifact: circuitArtifacts[i], pathToGeneratedFile: path.join(this.getOutputTypesDir(), circuitTypePath), - protocol: circuitArtifacts[i].baseCircuitInfo.protocol.length > 1 ? preparedNode.prefix : undefined, + protocol: preparedNode.protocol, }); } } - await this._resolveTypePaths(typePathsToResolve); + await this._resolveTypePaths(circuitSet); + await this._saveMainIndexFile(circuitSet); + await this._saveHardhatZkitTypeExtensionFile(circuitSet); // copy utils to types output dir const utilsDirPath = this.getOutputTypesDir(); @@ -119,81 +125,87 @@ export class CircuitTypesGenerator extends ZkitTSGenerator { } /** - * Generates the index files in the `TYPES_DIR` directory and its subdirectories. - * - * @param {ArtifactWithPath[]} typePaths - The paths to the generated files and the corresponding circuit artifacts. + * Generates the index files in the subdirectories of the `TYPES_DIR` directory. */ - private async _resolveTypePaths(typePaths: ArtifactWithPath[]): Promise { + private async _resolveTypePaths(circuitSet: CircuitSet): Promise { const rootTypesDirPath = this.getOutputTypesDir(); - const pathToMainIndexFile = path.join(rootTypesDirPath, "index.ts"); // index file path => its content - const indexFilesMap: Map = new Map(); - const isCircuitNameExist: Map = new Map(); - - const topLevelCircuits: { - [circuitName: string]: ArtifactWithPath[]; - } = {}; - - for (const typePath of typePaths) { - const levels: string[] = typePath.pathToGeneratedFile - .replace(this.getOutputTypesDir(), "") - .split(path.sep) - .filter((level) => level !== ""); - - for (let i = 0; i < levels.length; i++) { - const pathToIndexFile = - i === 0 - ? path.join(rootTypesDirPath, "index.ts") - : path.join(rootTypesDirPath, levels.slice(0, i).join(path.sep), "index.ts"); - - const exportDeclaration = - path.extname(levels[i]) === ".ts" - ? this._getExportDeclarationForFile(levels[i]) - : this._getExportDeclarationForDirectory(levels[i]); - - if ( - indexFilesMap.get(pathToIndexFile) === undefined || - !indexFilesMap.get(pathToIndexFile)?.includes(exportDeclaration) - ) { - indexFilesMap.set(pathToIndexFile, [ - ...(indexFilesMap.get(pathToIndexFile) === undefined ? [] : indexFilesMap.get(pathToIndexFile)!), - exportDeclaration, - ]); + const indexFilesMap: Map> = new Map(); + + for (const [, artifactWithPaths] of Object.entries(circuitSet)) { + for (const artifactWithPath of artifactWithPaths) { + const levels: string[] = artifactWithPath.pathToGeneratedFile + .replace(this.getOutputTypesDir(), "") + .split(path.sep) + .filter((level) => level !== ""); + + for (let i = 1; i < levels.length; i++) { + const pathToIndexFile = path.join(rootTypesDirPath, levels.slice(0, i).join(path.sep), "index.ts"); + + if (!indexFilesMap.has(pathToIndexFile)) { + indexFilesMap.set(pathToIndexFile, new Set()); + } + + const exportDeclaration = + path.extname(levels[i]) === ".ts" + ? this._getExportDeclarationForFile(levels[i]) + : this._getExportDeclarationForDirectory(levels[i]); + + if ( + indexFilesMap.get(pathToIndexFile) === undefined || + !indexFilesMap.get(pathToIndexFile)!.has(exportDeclaration) + ) { + indexFilesMap.set(pathToIndexFile, indexFilesMap.get(pathToIndexFile)!.add(exportDeclaration)); + } } } + } + + for (const [absolutePath, content] of indexFilesMap) { + this._saveFileContent(path.relative(this.getOutputTypesDir(), absolutePath), Array.from(content).join("\n")); + } + } + + private async _saveMainIndexFile(circuitSet: CircuitSet): Promise { + let mainIndexFileContent = this._getExportDeclarationForDirectory(CircuitTypesGenerator.DOMAIN_SEPARATOR) + "\n"; + + for (const [, artifactWithPaths] of Object.entries(circuitSet)) { + let isCircuitNameOverlaps = false; + const seenProtocols: string[] = []; + + for (const artifactWithPath of artifactWithPaths) { + if (seenProtocols.includes(artifactWithPath.protocol)) { + isCircuitNameOverlaps = true; + break; + } - const circuitName = typePath.circuitArtifact.circuitTemplateName; + seenProtocols.push(artifactWithPath.protocol); + } - if ( - isCircuitNameExist.get(circuitName) === undefined || - isCircuitNameExist.get(circuitName)! < typePath.circuitArtifact.baseCircuitInfo.protocol.length - ) { - indexFilesMap.set(pathToMainIndexFile, [ - ...(indexFilesMap.get(pathToMainIndexFile) === undefined ? [] : indexFilesMap.get(pathToMainIndexFile)!), - this._getExportDeclarationForFile(path.relative(this._projectRoot, levels.join(path.sep))), - ]); + if (isCircuitNameOverlaps) { + continue; } - isCircuitNameExist.set( - circuitName, - isCircuitNameExist.get(circuitName) === undefined ? 1 : isCircuitNameExist.get(circuitName)! + 1, - ); + for (const artifactWithPath of artifactWithPaths) { + const levels: string[] = artifactWithPath.pathToGeneratedFile + .replace(this.getOutputTypesDir(), "") + .split(path.sep) + .filter((level) => level !== ""); - topLevelCircuits[circuitName] = - topLevelCircuits[circuitName] === undefined ? [typePath] : [...topLevelCircuits[circuitName], typePath]; - } + const exportPathToCircuitType = this._getExportDeclarationForFile( + path.relative(this._projectRoot, levels.join(path.sep)), + ); - for (const [absolutePath, content] of indexFilesMap) { - this._saveFileContent(path.relative(this.getOutputTypesDir(), absolutePath), content.join("\n")); + mainIndexFileContent += exportPathToCircuitType + "\n"; + } } - const pathToTypesExtensionFile = path.join(rootTypesDirPath, "hardhat.d.ts"); + this._saveFileContent("index.ts", mainIndexFileContent); + } - this._saveFileContent( - path.relative(this.getOutputTypesDir(), pathToTypesExtensionFile), - await this._genHardhatZkitTypeExtension(topLevelCircuits), - ); + private async _saveHardhatZkitTypeExtensionFile(circuitSet: CircuitSet): Promise { + this._saveFileContent("hardhat.d.ts", await this._genHardhatZkitTypeExtension(circuitSet)); } /** diff --git a/src/core/ZkitTSGenerator.ts b/src/core/ZkitTSGenerator.ts index 247d4ab..3ebfd7f 100644 --- a/src/core/ZkitTSGenerator.ts +++ b/src/core/ZkitTSGenerator.ts @@ -7,15 +7,15 @@ import prettier from "prettier"; import BaseTSGenerator from "./BaseTSGenerator"; import { - ArtifactWithPath, CircuitArtifact, CircuitClass, Inputs, TypeExtensionTemplateParams, - DefaultWrapperTemplateParams, WrapperTemplateParams, SignalInfo, GeneratedCircuitWrapperResult, + CircuitSet, + ProtocolType, } from "../types"; import { normalizeName } from "../utils"; @@ -23,9 +23,7 @@ import { SignalTypeNames, SignalVisibilityNames } from "../constants"; import { Groth16CalldataPointsType, PlonkCalldataPointsType } from "../constants/protocol"; export default class ZkitTSGenerator extends BaseTSGenerator { - protected async _genHardhatZkitTypeExtension(circuits: { - [circuitName: string]: ArtifactWithPath[]; - }): Promise { + protected async _genHardhatZkitTypeExtension(circuits: CircuitSet): Promise { const template = fs.readFileSync(path.join(__dirname, "templates", "type-extension.ts.ejs"), "utf8"); const circuitClasses: CircuitClass[] = []; @@ -90,11 +88,9 @@ export default class ZkitTSGenerator extends BaseTSGenerator { circuitArtifact: CircuitArtifact, pathToGeneratedFile: string, ): Promise { - this._validateCircuitArtifact(circuitArtifact); - const result: GeneratedCircuitWrapperResult[] = []; - const unifiedProtocolType = new Set(circuitArtifact.baseCircuitInfo.protocol); + const unifiedProtocolType = this._getUnifiedProtocolType(circuitArtifact); for (const protocolType of unifiedProtocolType) { const content = await this._genSingleCircuitWrapperClassContent( circuitArtifact, @@ -109,16 +105,6 @@ export default class ZkitTSGenerator extends BaseTSGenerator { return result; } - protected async _genDefaultCircuitWrapperClassContent(circuitArtifact: CircuitArtifact): Promise { - const template = fs.readFileSync(path.join(__dirname, "templates", "default-circuit-wrapper.ts.ejs"), "utf8"); - - const templateParams: DefaultWrapperTemplateParams = { - circuitClassName: this._getCircuitName(circuitArtifact), - }; - - return await prettier.format(ejs.render(template, templateParams), { parser: "typescript" }); - } - private async _genSingleCircuitWrapperClassContent( circuitArtifact: CircuitArtifact, pathToGeneratedFile: string, @@ -175,20 +161,21 @@ export default class ZkitTSGenerator extends BaseTSGenerator { protocolImplementerName: this._getProtocolImplementerName(protocolType), proofTypeInternalName: this._getProofTypeInternalName(protocolType), circuitClassName, - publicInputsTypeName: this._getTypeName(circuitArtifact, "Public"), + publicInputsTypeName: this._getTypeName(circuitArtifact, this._getPrefix(protocolType), "Public"), calldataPubSignalsType: this._getCalldataPubSignalsType(calldataPubSignalsCount), publicInputs, privateInputs, calldataPointsType: this._getCalldataPointsType(protocolType), - proofTypeName: this._getTypeName(circuitArtifact, "Proof"), - privateInputsTypeName: this._getTypeName(circuitArtifact, "Private"), + proofTypeName: this._getTypeName(circuitArtifact, this._getPrefix(protocolType), "Proof"), + calldataTypeName: this._getTypeName(circuitArtifact, this._getPrefix(protocolType), "Calldata"), + privateInputsTypeName: this._getTypeName(circuitArtifact, this._getPrefix(protocolType), "Private"), pathToUtils: path.relative(path.dirname(pathToGeneratedFile), pathToUtils), }; return { content: await prettier.format(ejs.render(template, templateParams), { parser: "typescript" }), className: circuitClassName, - prefix: this._getPrefix(protocolType).toLowerCase(), + protocol: protocolType, }; } @@ -250,9 +237,11 @@ export default class ZkitTSGenerator extends BaseTSGenerator { } } - private _validateCircuitArtifact(circuitArtifact: CircuitArtifact): void { + private _getUnifiedProtocolType(circuitArtifact: CircuitArtifact): Set { if (!circuitArtifact.baseCircuitInfo.protocol) { - throw new Error(`ZKType: Protocol is missing in the circuit artifact: ${circuitArtifact.circuitTemplateName}`); + return new Set(["groth16"]); } + + return new Set(circuitArtifact.baseCircuitInfo.protocol); } } diff --git a/src/core/templates/circuit-wrapper.ts.ejs b/src/core/templates/circuit-wrapper.ts.ejs index fd97eda..6cb86b5 100644 --- a/src/core/templates/circuit-wrapper.ts.ejs +++ b/src/core/templates/circuit-wrapper.ts.ejs @@ -29,7 +29,7 @@ export type <%= proofTypeName %> = { publicSignals: <%= publicInputsTypeName %>; } -export type Calldata = [ +export type <%= calldataTypeName %> = [ <%= calldataPointsType %>, <%= calldataPubSignalsType %>, ]; @@ -59,7 +59,7 @@ export class <%= circuitClassName %> extends CircuitZKit<"<%= protocolTypeName % }); } - public async generateCalldata(proof: <%= proofTypeName %>): Promise { + public async generateCalldata(proof: <%= proofTypeName %>): Promise<<%= calldataTypeName %>> { return super.generateCalldata({ proof: proof.proof, publicSignals: this._denormalizePublicSignals(proof.publicSignals), diff --git a/src/core/templates/default-circuit-wrapper.ts.ejs b/src/core/templates/default-circuit-wrapper.ts.ejs deleted file mode 100644 index 6fdf7e4..0000000 --- a/src/core/templates/default-circuit-wrapper.ts.ejs +++ /dev/null @@ -1,37 +0,0 @@ -import { - Signals, - Calldata, - ProofStruct, - CircuitZKit, - CircuitZKitConfig -} from "@solarity/zkit"; - -export class <%= circuitClassName %> extends CircuitZKit { - constructor(config: CircuitZKitConfig) { - super(config); - } - - public async generateProof(inputs: Signals): Promise { - return super.generateProof(inputs); - } - - public async calculateWitness(inputs: Signals): Promise { - return super.calculateWitness(inputs); - } - - public async verifyProof(proof: ProofStruct): Promise { - return super.verifyProof({ - proof: proof.proof, - publicSignals: proof.publicSignals, - }); - } - - public async generateCalldata(proof: ProofStruct): Promise { - return super.generateCalldata({ - proof: proof.proof, - publicSignals: proof.publicSignals, - }); - } -} - -export default <%= circuitClassName %>; diff --git a/src/types/circuitArtifact.ts b/src/types/circuitArtifact.ts index cc880ad..9fa042d 100644 --- a/src/types/circuitArtifact.ts +++ b/src/types/circuitArtifact.ts @@ -1,5 +1,7 @@ export type FormatTypes = "hh-zkit-artifacts-1"; +export type ProtocolType = "groth16" | "plonk"; + export type SignalType = "Output" | "Input" | "Intermediate"; export type VisibilityType = "Public" | "Private"; diff --git a/src/types/typesGenerator.ts b/src/types/typesGenerator.ts index 43e8cf9..2e2ac56 100644 --- a/src/types/typesGenerator.ts +++ b/src/types/typesGenerator.ts @@ -1,10 +1,14 @@ -import { CircuitArtifact } from "./circuitArtifact"; +import { CircuitArtifact, ProtocolType } from "./circuitArtifact"; import { Groth16CalldataPointsType, PlonkCalldataPointsType } from "../constants/protocol"; +export interface CircuitSet { + [circuitName: string]: ArtifactWithPath[]; +} + export interface ArtifactWithPath { circuitArtifact: CircuitArtifact; pathToGeneratedFile: string; - protocol?: string; + protocol: string; } export interface Inputs { @@ -13,12 +17,8 @@ export interface Inputs { dimensionsArray: string; } -export interface DefaultWrapperTemplateParams { - circuitClassName: string; -} - export interface WrapperTemplateParams { - protocolTypeName: "groth16" | "plonk"; + protocolTypeName: ProtocolType; protocolImplementerName: "Groth16Implementer" | "PlonkImplementer"; proofTypeInternalName: "Groth16Proof" | "PlonkProof"; publicInputsTypeName: string; @@ -28,6 +28,7 @@ export interface WrapperTemplateParams { calldataPubSignalsType: string; proofTypeName: string; privateInputsTypeName: string; + calldataTypeName: string; circuitClassName: string; pathToUtils: string; } @@ -45,5 +46,5 @@ export interface TypeExtensionTemplateParams { export interface GeneratedCircuitWrapperResult { content: string; className: string; - prefix: string; + protocol: string; } diff --git a/test/fixture-cache/auth/Matrix_artifacts.json b/test/fixture-cache/auth/Matrix_artifacts.json index 145a49b..a3ccec3 100644 --- a/test/fixture-cache/auth/Matrix_artifacts.json +++ b/test/fixture-cache/auth/Matrix_artifacts.json @@ -5,7 +5,6 @@ "circuitSourceName": "circuits/fixture/auth/Matrix.circom", "baseCircuitInfo": { "constraintsNumber": 8, - "protocol": ["groth16"], "signals": [ { "name": "a", diff --git a/test/helpers/index.ts b/test/helpers/index.ts index 8255191..79bb636 100644 --- a/test/helpers/index.ts +++ b/test/helpers/index.ts @@ -2,9 +2,16 @@ import { CircuitTypesGenerator } from "../../src"; import { findProjectRoot } from "../../src/utils"; const circuitTypesGenerator = new CircuitTypesGenerator({ - basePath: "test/fixture", + basePath: "circuits/fixture", projectRoot: findProjectRoot(process.cwd()), - circuitsArtifactsPaths: ["test/fixture-cache/Multiplier2_artifacts.json"], + circuitsArtifactsPaths: [ + "test/fixture-cache/auth/EnhancedMultiplier_artifacts.json", + "test/fixture-cache/auth/Matrix_artifacts.json", + "test/fixture-cache/auth/Multiplier2_artifacts.json", + "test/fixture-cache/lib/Multiplier2_artifacts.json", + "test/fixture-cache/CredentialAtomicQueryMTPOnChainVoting_artifacts.json", + "test/fixture-cache/Multiplier2_artifacts.json", + ], }); // circuitTypesGenerator.generateTypes().then(console.log).catch(console.error);