Skip to content

Commit

Permalink
feat: support for trustless async decrypt
Browse files Browse the repository at this point in the history
  • Loading branch information
jatZama committed Sep 17, 2024
1 parent c1a8df6 commit da2c96b
Show file tree
Hide file tree
Showing 6 changed files with 212 additions and 6 deletions.
62 changes: 62 additions & 0 deletions examples/TestAsyncDecrypt.sol
Original file line number Diff line number Diff line change
Expand Up @@ -345,4 +345,66 @@ contract TestAsyncDecrypt is GatewayCaller {
yAddress = decAddress;
yBytes256 = bytesRes;
}

function requestEbytes256NonTrivialTrustless(einput inputHandle, bytes calldata inputProof) public {
ebytes256 inputNonTrivial = TFHE.asEbytes256(inputHandle, inputProof);
uint256[] memory cts = new uint256[](1);
cts[0] = Gateway.toUint256(inputNonTrivial);
uint256 requestID = Gateway.requestDecryption(
cts,
this.callbackBytes256Trustless.selector,
0,
block.timestamp + 100,
true
);
latestRequestID = requestID;
saveRequestedHandles(requestID, cts);
}

function callbackBytes256Trustless(
uint256 requestID,
bytes calldata decryptedInput,
bytes[] memory signatures
) public onlyGateway returns (bytes memory) {
require(latestRequestID == requestID, "wrong requestID passed by Gateway");
uint256[] memory requestedHandles = loadRequestedHandles(latestRequestID);
bool isKMSVerified = Gateway.verifySignatures(requestedHandles, signatures);
require(isKMSVerified, "KMS did not verify this decryption result");
yBytes256 = decryptedInput;
return decryptedInput;
}

function requestMixedBytes256Trustless(einput inputHandle, bytes calldata inputProof) public {
ebytes256 xBytes256 = TFHE.asEbytes256(inputHandle, inputProof);
uint256[] memory cts = new uint256[](3);
cts[0] = Gateway.toUint256(xBool);
cts[1] = Gateway.toUint256(xBytes256);
cts[2] = Gateway.toUint256(xAddress);
Gateway.requestDecryption(cts, this.callbackMixedBytes256Trustless.selector, 0, block.timestamp + 100, true);
uint256 requestID = Gateway.requestDecryption(
cts,
this.callbackMixedBytes256Trustless.selector,
0,
block.timestamp + 100,
true
);
latestRequestID = requestID;
saveRequestedHandles(requestID, cts);
}

function callbackMixedBytes256Trustless(
uint256 requestID,
bool decBool,
bytes memory bytesRes,
address decAddress,
bytes[] memory signatures
) public onlyGateway {
require(latestRequestID == requestID, "wrong requestID passed by Gateway");
uint256[] memory requestedHandles = loadRequestedHandles(latestRequestID);
bool isKMSVerified = Gateway.verifySignatures(requestedHandles, signatures);
require(isKMSVerified, "KMS did not verify this decryption result");
yBool = decBool;
yAddress = decAddress;
yBytes256 = bytesRes;
}
}
14 changes: 13 additions & 1 deletion gateway/GatewayContract.sol
Original file line number Diff line number Diff line change
Expand Up @@ -186,15 +186,27 @@ contract GatewayContract is UUPSUpgradeable, Ownable2StepUpgradeable {
bool passSignatures = decryptionReq.passSignaturesToCaller;
callbackCalldata = abi.encodePacked(callbackCalldata, decryptedCts); // decryptedCts MUST be correctly abi-encoded by the relayer, according to the requested types of `ctsHandles`
if (passSignatures) {
callbackCalldata = abi.encodePacked(callbackCalldata, abi.encode(signatures));
bytes memory packedSignatures = abi.encode(signatures);
bytes memory packedSignaturesNoOffset = removeOffset(packedSignatures); // remove the offset (the first 32 bytes) before concatenating with the first part of calldata
callbackCalldata = abi.encodePacked(callbackCalldata, packedSignaturesNoOffset);
}

(bool success, bytes memory result) = (decryptionReq.contractCaller).call{value: decryptionReq.msgValue}(
callbackCalldata
);
emit ResultCallback(requestID, success, result);
$.isFulfilled[requestID] = true;
}

function removeOffset(bytes memory input) public pure virtual returns (bytes memory) {
uint256 newLength = input.length - 32;
bytes memory result = new bytes(newLength);
for (uint256 i = 0; i < newLength; i++) {
result[i] = input[i + 32];
}
return result;
}

/// @notice Getter for the name and version of the contract
/// @return string representing the name and the version of the contract
function getVersion() external pure virtual returns (string memory) {
Expand Down
39 changes: 37 additions & 2 deletions gateway/lib/Gateway.sol
Original file line number Diff line number Diff line change
Expand Up @@ -95,18 +95,53 @@ library Gateway {
);
}

/*function verifySignatures(uint256[] memory handlesList, bytes[] memory signatures) internal returns (bool) {
uint256 start = 4 + 32; // start position after skipping the selector (4 bytes) and the first argument (index, 32 bytes)
uint256 numArgs = handlesList.length; // Number of arguments before signatures
uint256 length = numArgs * 32 + 32; // TODO: fix the way we compute length in case the type of the handle is an ebytes256 (loop over all handles and add correct length corresponding to each type)
bytes memory decryptedResult = new bytes(length);
assembly {
calldatacopy(add(decryptedResult, 0x20), start, length) // Copy the relevant part of calldata to decryptedResult memory
}
FHEVMConfig.FHEVMConfigStruct storage $ = Impl.getFHEVMConfig();
return IKMSVerifier($.KMSVerifierAddress).verifySignatures(handlesList, decryptedResult, signatures);
}*/

/// @dev this function is supposed to be called inside the callback function if the dev wants the dApp contract to verify the signatures
/// @dev this is useful to give dev the choice not to rely on trusting the GatewayContract.
/// @notice this could be used only when signatures are made available to the callback, i.e when `passSignaturesToCaller` is set to true during request
function verifySignatures(uint256[] memory handlesList, bytes[] memory signatures) internal returns (bool) {
uint256 start = 4 + 32; // start position after skipping the selector (4 bytes) and the first argument (index, 32 bytes)
uint256 numArgs = handlesList.length; // Number of arguments before signatures
uint256 length = numArgs * 32; // TODO: fix the way we compute length in case the type of the handle is an ebytes256 (loop over all handles and add correct length corresponding to each type)
uint256 length = getSignedDataLength(handlesList);
bytes memory decryptedResult = new bytes(length);
assembly {
calldatacopy(add(decryptedResult, 0x20), start, length) // Copy the relevant part of calldata to decryptedResult memory
}
FHEVMConfig.FHEVMConfigStruct storage $ = Impl.getFHEVMConfig();
return IKMSVerifier($.KMSVerifierAddress).verifySignatures(handlesList, decryptedResult, signatures);
}

function getSignedDataLength(uint256[] memory handlesList) private pure returns (uint256) {
uint256 handlesListlen = handlesList.length;
uint256 signedDataLength;
for (uint256 i = 0; i < handlesListlen; i++) {
uint8 typeCt = uint8(handlesList[i] >> 8);
if (typeCt < 9) {
signedDataLength += 32;
} else if (typeCt == 9) {
//ebytes64
signedDataLength += 128;
} else if (typeCt == 10) {
//ebytes128
signedDataLength += 192;
} else if (typeCt == 11) {
//ebytes256
signedDataLength += 320;
} else {
revert("Unsupported handle type");
}
}
signedDataLength += 32; // for the signatures offset
return signedDataLength;
}
}
13 changes: 10 additions & 3 deletions test/asyncDecrypt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ const fulfillAllPastRequestsIds = async (mocked: boolean) => {
const handles = event.args[1];
const typesList = handles.map((handle) => parseInt(handle.toString(16).slice(-4, -2), 16));
const msgValue = event.args[4];
const passSignaturesToCaller = event.args[6];
if (!results.includes(requestID)) {
// if request is not already fulfilled
if (mocked) {
Expand All @@ -140,12 +141,18 @@ const fulfillAllPastRequestsIds = async (mocked: boolean) => {
);

const abiCoder = new ethers.AbiCoder();
const encodedData = abiCoder.encode(['uint256', ...types], [31, ...valuesFormatted2]); // 31 is just a dummy uint256 requestID to get correct abi encoding for the remaining arguments (i.e everything except the requestID)
const calldata = '0x' + encodedData.slice(66); // we just pop the dummy requestID to get the correct value to pass for `decryptedCts`
let encodedData;
let calldata;
if (!passSignaturesToCaller) {
encodedData = abiCoder.encode(['uint256', ...types], [31, ...valuesFormatted2]); // 31 is just a dummy uint256 requestID to get correct abi encoding for the remaining arguments (i.e everything except the requestID)
calldata = '0x' + encodedData.slice(66); // we just pop the dummy requestID to get the correct value to pass for `decryptedCts`
} else {
encodedData = abiCoder.encode(['uint256', ...types, 'bytes[]'], [31, ...valuesFormatted2, []]); // adding also a dummy empty array of bytes for correct abi-encoding when used with signatures
calldata = '0x' + encodedData.slice(66).slice(0, -64); // we also pop the last 32 bytes (empty bytes[])
}

const numSigners = +process.env.NUM_KMS_SIGNERS!;
const decryptResultsEIP712signatures = await computeDecryptSignatures(handles, calldata, numSigners);

const tx = await gateway
.connect(relayer)
.fulfillRequest(requestID, calldata, decryptResultsEIP712signatures, { value: msgValue });
Expand Down
62 changes: 62 additions & 0 deletions test/gatewayDecrypt/testAsyncDecrypt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,18 @@ describe('TestAsyncDecrypt', function () {
console.log('gas paid by user (request tx) : ', balanceBeforeU - balanceAfterU);
});

it('test async decrypt bool trustless', async function () {
const contractFactory = await ethers.getContractFactory('TestAsyncDecrypt');
const contract2 = await contractFactory.connect(this.signers.alice).deploy({
value: ethers.parseEther('0.001'),
});
const tx2 = await contract2.requestBoolTrustless({ gasLimit: 5_000_000 });
await tx2.wait();
await awaitAllDecryptionResults();
const y = await contract2.yBool();
expect(y).to.equal(true);
});

it.skip('test async decrypt FAKE bool', async function () {
if (network.name !== 'hardhat') {
// only in fhevm mode
Expand Down Expand Up @@ -377,4 +389,54 @@ describe('TestAsyncDecrypt', function () {
const yAdd = await this.contract.yAddress();
expect(yAdd).to.equal('0x8ba1f109551bD432803012645Ac136ddd64DBA72');
});

it('test async decrypt ebytes256 non-trivial trustless', async function () {
const contractFactory = await ethers.getContractFactory('TestAsyncDecrypt');
const contract2 = await contractFactory.connect(this.signers.alice).deploy({
value: ethers.parseEther('0.001'),
});
const inputAlice = this.instances.alice.createEncryptedInput(
await contract2.getAddress(),
this.signers.alice.address,
);
inputAlice.addBytes256(bigIntToBytes(18446744073709550022n));
const encryptedAmount = inputAlice.encrypt();
const tx = await contract2.requestEbytes256NonTrivialTrustless(
encryptedAmount.handles[0],
encryptedAmount.inputProof,
{ gasLimit: 5_000_000 },
);
await tx.wait();
await awaitAllDecryptionResults();
const y = await contract2.yBytes256();
expect(y).to.equal(ethers.toBeHex(18446744073709550022n, 256));
});

it('test async decrypt mixed with ebytes256 trustless', async function () {
const contractFactory = await ethers.getContractFactory('TestAsyncDecrypt');
const contract2 = await contractFactory.connect(this.signers.alice).deploy({
value: ethers.parseEther('0.001'),
});
const inputAlice = this.instances.alice.createEncryptedInput(
await contract2.getAddress(),
this.signers.alice.address,
);
inputAlice.addBytes256(bigIntToBytes(18446744073709550032n));
const encryptedAmount = inputAlice.encrypt();
const tx = await await contract2.requestMixedBytes256Trustless(
encryptedAmount.handles[0],
encryptedAmount.inputProof,
{
gasLimit: 5_000_000,
},
);
await tx.wait();
await awaitAllDecryptionResults();
const y = await contract2.yBytes256();
expect(y).to.equal(ethers.toBeHex(18446744073709550032n, 256));
const yb = await contract2.yBool();
expect(yb).to.equal(true);
const yAdd = await contract2.yAddress();
expect(yAdd).to.equal('0x8ba1f109551bD432803012645Ac136ddd64DBA72');
});
});
28 changes: 28 additions & 0 deletions test/kmsVerifier/kmsVerifier.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@ import fs from 'fs';
import { ethers } from 'hardhat';

import { asyncDecrypt, awaitAllDecryptionResults } from '../asyncDecrypt';
import { createInstances } from '../instance';
import { getSigners, initSigners } from '../signers';
import { bigIntToBytes } from '../utils';

describe('KMSVerifier', function () {
before(async function () {
await initSigners(2);
this.signers = await getSigners();
this.instances = await createInstances(this.signers);
this.kmsFactory = await ethers.getContractFactory('KMSVerifier');
await asyncDecrypt();
});
Expand Down Expand Up @@ -76,6 +79,31 @@ describe('KMSVerifier', function () {
const y4 = await contract.yUint8();
expect(y4).to.equal(42); // even with more than 2 signatures decryption should still succeed

const contract2 = await contractFactory.connect(this.signers.alice).deploy({
value: ethers.parseEther('0.001'),
});
const inputAlice = this.instances.alice.createEncryptedInput(
await contract2.getAddress(),
this.signers.alice.address,
);
inputAlice.addBytes256(bigIntToBytes(18446744073709550032n));
const encryptedAmount = inputAlice.encrypt();
const tx6bis = await await contract2.requestMixedBytes256Trustless(
encryptedAmount.handles[0],
encryptedAmount.inputProof,
{
gasLimit: 5_000_000,
},
);
await tx6bis.wait();
await awaitAllDecryptionResults();
const ybis = await contract2.yBytes256();
expect(ybis).to.equal(ethers.toBeHex(18446744073709550032n, 256));
const yb = await contract2.yBool();
expect(yb).to.equal(true);
const yAdd = await contract2.yAddress();
expect(yAdd).to.equal('0x8ba1f109551bD432803012645Ac136ddd64DBA72'); // testing trustless mixed with ebytes256, in case of several signatures

process.env.NUM_KMS_SIGNERS = '2';
process.env.PRIVATE_KEY_KMS_SIGNER_1 = process.env.PRIVATE_KEY_KMS_SIGNER_0;
const tx7 = await contract.requestUint16({ gasLimit: 5_000_000 });
Expand Down

0 comments on commit da2c96b

Please sign in to comment.