From da2c96b2cb84cdfb852aa9a40b07e524182cbd33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joseph-Andr=C3=A9=20Turk?= Date: Tue, 17 Sep 2024 13:30:19 +0200 Subject: [PATCH] feat: support for trustless async decrypt --- examples/TestAsyncDecrypt.sol | 62 +++++++++++++++++++++++++ gateway/GatewayContract.sol | 14 +++++- gateway/lib/Gateway.sol | 39 +++++++++++++++- test/asyncDecrypt.ts | 13 ++++-- test/gatewayDecrypt/testAsyncDecrypt.ts | 62 +++++++++++++++++++++++++ test/kmsVerifier/kmsVerifier.ts | 28 +++++++++++ 6 files changed, 212 insertions(+), 6 deletions(-) diff --git a/examples/TestAsyncDecrypt.sol b/examples/TestAsyncDecrypt.sol index df3ba797..a305dc7a 100644 --- a/examples/TestAsyncDecrypt.sol +++ b/examples/TestAsyncDecrypt.sol @@ -345,4 +345,66 @@ contract TestAsyncDecrypt is GatewayCaller { yAddress = decAddress; yBytes256 = bytesRes; } + + function requestEbytes256NonTrivialTrustless(einput inputHandle, bytes calldata inputProof) public { + ebytes256 inputNonTrivial = TFHE.asEbytes256(inputHandle, inputProof); + uint256[] memory cts = new uint256[](1); + cts[0] = Gateway.toUint256(inputNonTrivial); + uint256 requestID = Gateway.requestDecryption( + cts, + this.callbackBytes256Trustless.selector, + 0, + block.timestamp + 100, + true + ); + latestRequestID = requestID; + saveRequestedHandles(requestID, cts); + } + + function callbackBytes256Trustless( + uint256 requestID, + bytes calldata decryptedInput, + bytes[] memory signatures + ) public onlyGateway returns (bytes memory) { + require(latestRequestID == requestID, "wrong requestID passed by Gateway"); + uint256[] memory requestedHandles = loadRequestedHandles(latestRequestID); + bool isKMSVerified = Gateway.verifySignatures(requestedHandles, signatures); + require(isKMSVerified, "KMS did not verify this decryption result"); + yBytes256 = decryptedInput; + return decryptedInput; + } + + function requestMixedBytes256Trustless(einput inputHandle, bytes calldata inputProof) public { + ebytes256 xBytes256 = TFHE.asEbytes256(inputHandle, inputProof); + uint256[] memory cts = new uint256[](3); + cts[0] = Gateway.toUint256(xBool); + cts[1] = Gateway.toUint256(xBytes256); + cts[2] = Gateway.toUint256(xAddress); + Gateway.requestDecryption(cts, this.callbackMixedBytes256Trustless.selector, 0, block.timestamp + 100, true); + uint256 requestID = Gateway.requestDecryption( + cts, + this.callbackMixedBytes256Trustless.selector, + 0, + block.timestamp + 100, + true + ); + latestRequestID = requestID; + saveRequestedHandles(requestID, cts); + } + + function callbackMixedBytes256Trustless( + uint256 requestID, + bool decBool, + bytes memory bytesRes, + address decAddress, + bytes[] memory signatures + ) public onlyGateway { + require(latestRequestID == requestID, "wrong requestID passed by Gateway"); + uint256[] memory requestedHandles = loadRequestedHandles(latestRequestID); + bool isKMSVerified = Gateway.verifySignatures(requestedHandles, signatures); + require(isKMSVerified, "KMS did not verify this decryption result"); + yBool = decBool; + yAddress = decAddress; + yBytes256 = bytesRes; + } } diff --git a/gateway/GatewayContract.sol b/gateway/GatewayContract.sol index 64fb0747..0a62301e 100644 --- a/gateway/GatewayContract.sol +++ b/gateway/GatewayContract.sol @@ -186,8 +186,11 @@ contract GatewayContract is UUPSUpgradeable, Ownable2StepUpgradeable { bool passSignatures = decryptionReq.passSignaturesToCaller; callbackCalldata = abi.encodePacked(callbackCalldata, decryptedCts); // decryptedCts MUST be correctly abi-encoded by the relayer, according to the requested types of `ctsHandles` if (passSignatures) { - callbackCalldata = abi.encodePacked(callbackCalldata, abi.encode(signatures)); + bytes memory packedSignatures = abi.encode(signatures); + bytes memory packedSignaturesNoOffset = removeOffset(packedSignatures); // remove the offset (the first 32 bytes) before concatenating with the first part of calldata + callbackCalldata = abi.encodePacked(callbackCalldata, packedSignaturesNoOffset); } + (bool success, bytes memory result) = (decryptionReq.contractCaller).call{value: decryptionReq.msgValue}( callbackCalldata ); @@ -195,6 +198,15 @@ contract GatewayContract is UUPSUpgradeable, Ownable2StepUpgradeable { $.isFulfilled[requestID] = true; } + function removeOffset(bytes memory input) public pure virtual returns (bytes memory) { + uint256 newLength = input.length - 32; + bytes memory result = new bytes(newLength); + for (uint256 i = 0; i < newLength; i++) { + result[i] = input[i + 32]; + } + return result; + } + /// @notice Getter for the name and version of the contract /// @return string representing the name and the version of the contract function getVersion() external pure virtual returns (string memory) { diff --git a/gateway/lib/Gateway.sol b/gateway/lib/Gateway.sol index eb45ad94..e15e212c 100644 --- a/gateway/lib/Gateway.sol +++ b/gateway/lib/Gateway.sol @@ -95,13 +95,24 @@ library Gateway { ); } + /*function verifySignatures(uint256[] memory handlesList, bytes[] memory signatures) internal returns (bool) { + uint256 start = 4 + 32; // start position after skipping the selector (4 bytes) and the first argument (index, 32 bytes) + uint256 numArgs = handlesList.length; // Number of arguments before signatures + uint256 length = numArgs * 32 + 32; // TODO: fix the way we compute length in case the type of the handle is an ebytes256 (loop over all handles and add correct length corresponding to each type) + bytes memory decryptedResult = new bytes(length); + assembly { + calldatacopy(add(decryptedResult, 0x20), start, length) // Copy the relevant part of calldata to decryptedResult memory + } + FHEVMConfig.FHEVMConfigStruct storage $ = Impl.getFHEVMConfig(); + return IKMSVerifier($.KMSVerifierAddress).verifySignatures(handlesList, decryptedResult, signatures); + }*/ + /// @dev this function is supposed to be called inside the callback function if the dev wants the dApp contract to verify the signatures /// @dev this is useful to give dev the choice not to rely on trusting the GatewayContract. /// @notice this could be used only when signatures are made available to the callback, i.e when `passSignaturesToCaller` is set to true during request function verifySignatures(uint256[] memory handlesList, bytes[] memory signatures) internal returns (bool) { uint256 start = 4 + 32; // start position after skipping the selector (4 bytes) and the first argument (index, 32 bytes) - uint256 numArgs = handlesList.length; // Number of arguments before signatures - uint256 length = numArgs * 32; // TODO: fix the way we compute length in case the type of the handle is an ebytes256 (loop over all handles and add correct length corresponding to each type) + uint256 length = getSignedDataLength(handlesList); bytes memory decryptedResult = new bytes(length); assembly { calldatacopy(add(decryptedResult, 0x20), start, length) // Copy the relevant part of calldata to decryptedResult memory @@ -109,4 +120,28 @@ library Gateway { FHEVMConfig.FHEVMConfigStruct storage $ = Impl.getFHEVMConfig(); return IKMSVerifier($.KMSVerifierAddress).verifySignatures(handlesList, decryptedResult, signatures); } + + function getSignedDataLength(uint256[] memory handlesList) private pure returns (uint256) { + uint256 handlesListlen = handlesList.length; + uint256 signedDataLength; + for (uint256 i = 0; i < handlesListlen; i++) { + uint8 typeCt = uint8(handlesList[i] >> 8); + if (typeCt < 9) { + signedDataLength += 32; + } else if (typeCt == 9) { + //ebytes64 + signedDataLength += 128; + } else if (typeCt == 10) { + //ebytes128 + signedDataLength += 192; + } else if (typeCt == 11) { + //ebytes256 + signedDataLength += 320; + } else { + revert("Unsupported handle type"); + } + } + signedDataLength += 32; // for the signatures offset + return signedDataLength; + } } diff --git a/test/asyncDecrypt.ts b/test/asyncDecrypt.ts index fd28b2e9..43c2cfa8 100644 --- a/test/asyncDecrypt.ts +++ b/test/asyncDecrypt.ts @@ -116,6 +116,7 @@ const fulfillAllPastRequestsIds = async (mocked: boolean) => { const handles = event.args[1]; const typesList = handles.map((handle) => parseInt(handle.toString(16).slice(-4, -2), 16)); const msgValue = event.args[4]; + const passSignaturesToCaller = event.args[6]; if (!results.includes(requestID)) { // if request is not already fulfilled if (mocked) { @@ -140,12 +141,18 @@ const fulfillAllPastRequestsIds = async (mocked: boolean) => { ); const abiCoder = new ethers.AbiCoder(); - const encodedData = abiCoder.encode(['uint256', ...types], [31, ...valuesFormatted2]); // 31 is just a dummy uint256 requestID to get correct abi encoding for the remaining arguments (i.e everything except the requestID) - const calldata = '0x' + encodedData.slice(66); // we just pop the dummy requestID to get the correct value to pass for `decryptedCts` + let encodedData; + let calldata; + if (!passSignaturesToCaller) { + encodedData = abiCoder.encode(['uint256', ...types], [31, ...valuesFormatted2]); // 31 is just a dummy uint256 requestID to get correct abi encoding for the remaining arguments (i.e everything except the requestID) + calldata = '0x' + encodedData.slice(66); // we just pop the dummy requestID to get the correct value to pass for `decryptedCts` + } else { + encodedData = abiCoder.encode(['uint256', ...types, 'bytes[]'], [31, ...valuesFormatted2, []]); // adding also a dummy empty array of bytes for correct abi-encoding when used with signatures + calldata = '0x' + encodedData.slice(66).slice(0, -64); // we also pop the last 32 bytes (empty bytes[]) + } const numSigners = +process.env.NUM_KMS_SIGNERS!; const decryptResultsEIP712signatures = await computeDecryptSignatures(handles, calldata, numSigners); - const tx = await gateway .connect(relayer) .fulfillRequest(requestID, calldata, decryptResultsEIP712signatures, { value: msgValue }); diff --git a/test/gatewayDecrypt/testAsyncDecrypt.ts b/test/gatewayDecrypt/testAsyncDecrypt.ts index 3b511ee9..1d496adc 100644 --- a/test/gatewayDecrypt/testAsyncDecrypt.ts +++ b/test/gatewayDecrypt/testAsyncDecrypt.ts @@ -87,6 +87,18 @@ describe('TestAsyncDecrypt', function () { console.log('gas paid by user (request tx) : ', balanceBeforeU - balanceAfterU); }); + it('test async decrypt bool trustless', async function () { + const contractFactory = await ethers.getContractFactory('TestAsyncDecrypt'); + const contract2 = await contractFactory.connect(this.signers.alice).deploy({ + value: ethers.parseEther('0.001'), + }); + const tx2 = await contract2.requestBoolTrustless({ gasLimit: 5_000_000 }); + await tx2.wait(); + await awaitAllDecryptionResults(); + const y = await contract2.yBool(); + expect(y).to.equal(true); + }); + it.skip('test async decrypt FAKE bool', async function () { if (network.name !== 'hardhat') { // only in fhevm mode @@ -377,4 +389,54 @@ describe('TestAsyncDecrypt', function () { const yAdd = await this.contract.yAddress(); expect(yAdd).to.equal('0x8ba1f109551bD432803012645Ac136ddd64DBA72'); }); + + it('test async decrypt ebytes256 non-trivial trustless', async function () { + const contractFactory = await ethers.getContractFactory('TestAsyncDecrypt'); + const contract2 = await contractFactory.connect(this.signers.alice).deploy({ + value: ethers.parseEther('0.001'), + }); + const inputAlice = this.instances.alice.createEncryptedInput( + await contract2.getAddress(), + this.signers.alice.address, + ); + inputAlice.addBytes256(bigIntToBytes(18446744073709550022n)); + const encryptedAmount = inputAlice.encrypt(); + const tx = await contract2.requestEbytes256NonTrivialTrustless( + encryptedAmount.handles[0], + encryptedAmount.inputProof, + { gasLimit: 5_000_000 }, + ); + await tx.wait(); + await awaitAllDecryptionResults(); + const y = await contract2.yBytes256(); + expect(y).to.equal(ethers.toBeHex(18446744073709550022n, 256)); + }); + + it('test async decrypt mixed with ebytes256 trustless', async function () { + const contractFactory = await ethers.getContractFactory('TestAsyncDecrypt'); + const contract2 = await contractFactory.connect(this.signers.alice).deploy({ + value: ethers.parseEther('0.001'), + }); + const inputAlice = this.instances.alice.createEncryptedInput( + await contract2.getAddress(), + this.signers.alice.address, + ); + inputAlice.addBytes256(bigIntToBytes(18446744073709550032n)); + const encryptedAmount = inputAlice.encrypt(); + const tx = await await contract2.requestMixedBytes256Trustless( + encryptedAmount.handles[0], + encryptedAmount.inputProof, + { + gasLimit: 5_000_000, + }, + ); + await tx.wait(); + await awaitAllDecryptionResults(); + const y = await contract2.yBytes256(); + expect(y).to.equal(ethers.toBeHex(18446744073709550032n, 256)); + const yb = await contract2.yBool(); + expect(yb).to.equal(true); + const yAdd = await contract2.yAddress(); + expect(yAdd).to.equal('0x8ba1f109551bD432803012645Ac136ddd64DBA72'); + }); }); diff --git a/test/kmsVerifier/kmsVerifier.ts b/test/kmsVerifier/kmsVerifier.ts index d55b9c40..61c63535 100644 --- a/test/kmsVerifier/kmsVerifier.ts +++ b/test/kmsVerifier/kmsVerifier.ts @@ -4,12 +4,15 @@ import fs from 'fs'; import { ethers } from 'hardhat'; import { asyncDecrypt, awaitAllDecryptionResults } from '../asyncDecrypt'; +import { createInstances } from '../instance'; import { getSigners, initSigners } from '../signers'; +import { bigIntToBytes } from '../utils'; describe('KMSVerifier', function () { before(async function () { await initSigners(2); this.signers = await getSigners(); + this.instances = await createInstances(this.signers); this.kmsFactory = await ethers.getContractFactory('KMSVerifier'); await asyncDecrypt(); }); @@ -76,6 +79,31 @@ describe('KMSVerifier', function () { const y4 = await contract.yUint8(); expect(y4).to.equal(42); // even with more than 2 signatures decryption should still succeed + const contract2 = await contractFactory.connect(this.signers.alice).deploy({ + value: ethers.parseEther('0.001'), + }); + const inputAlice = this.instances.alice.createEncryptedInput( + await contract2.getAddress(), + this.signers.alice.address, + ); + inputAlice.addBytes256(bigIntToBytes(18446744073709550032n)); + const encryptedAmount = inputAlice.encrypt(); + const tx6bis = await await contract2.requestMixedBytes256Trustless( + encryptedAmount.handles[0], + encryptedAmount.inputProof, + { + gasLimit: 5_000_000, + }, + ); + await tx6bis.wait(); + await awaitAllDecryptionResults(); + const ybis = await contract2.yBytes256(); + expect(ybis).to.equal(ethers.toBeHex(18446744073709550032n, 256)); + const yb = await contract2.yBool(); + expect(yb).to.equal(true); + const yAdd = await contract2.yAddress(); + expect(yAdd).to.equal('0x8ba1f109551bD432803012645Ac136ddd64DBA72'); // testing trustless mixed with ebytes256, in case of several signatures + process.env.NUM_KMS_SIGNERS = '2'; process.env.PRIVATE_KEY_KMS_SIGNER_1 = process.env.PRIVATE_KEY_KMS_SIGNER_0; const tx7 = await contract.requestUint16({ gasLimit: 5_000_000 });