diff --git a/src/libspark/test/encrypt_test.cpp b/src/libspark/test/encrypt_test.cpp index d0849b81c8..072bed8281 100644 --- a/src/libspark/test/encrypt_test.cpp +++ b/src/libspark/test/encrypt_test.cpp @@ -25,6 +25,37 @@ BOOST_AUTO_TEST_CASE(complete) BOOST_CHECK_EQUAL(i_, i); } +BOOST_AUTO_TEST_CASE(overflow) +{ + // Number of bytes for our diversifier; this needs to exceed `uint64_t` bounds but not the AES block size + int BYTES = 10; + + // Key + std::string key_string = "Key prefix"; + std::vector key(key_string.begin(), key_string.end()); + key.resize(AES256_KEYSIZE); + + // Encrypt a value that will exceed `uint64_t` bounds + // We have to do this manually since the diversifier API won't let us! + std::vector plaintext; + plaintext.resize(BYTES); + for (int i = 0; i < BYTES; i++) { + plaintext[i] = 0xFF; // this will exceed the allowed bounds + } + + std::vector ciphertext; + ciphertext.resize(AES_BLOCKSIZE); + std::vector iv; + iv.resize(AES_BLOCKSIZE); + + AES256CBCEncrypt aes(key.data(), iv.data(), true); + plaintext.resize(AES_BLOCKSIZE); + aes.Encrypt(plaintext.data(), BYTES, ciphertext.data()); + + // Decrypt + BOOST_CHECK_THROW(SparkUtils::diversifier_decrypt(key, ciphertext), std::runtime_error); +} + BOOST_AUTO_TEST_CASE(bad_key) { // Key @@ -41,10 +72,8 @@ BOOST_AUTO_TEST_CASE(bad_key) uint64_t i = 12345; std::vector d = SparkUtils::diversifier_encrypt(key, i); - // Decrypt - uint64_t i_ = SparkUtils::diversifier_decrypt(evil_key, d); - - BOOST_CHECK_NE(i_, i); + // Decryption induces a padding failure, so no plaintext is returned + BOOST_CHECK_THROW(SparkUtils::diversifier_decrypt(evil_key, d), std::runtime_error); } BOOST_AUTO_TEST_SUITE_END() diff --git a/src/libspark/util.cpp b/src/libspark/util.cpp index 4547251320..8f378717d8 100644 --- a/src/libspark/util.cpp +++ b/src/libspark/util.cpp @@ -40,15 +40,18 @@ uint64_t SparkUtils::diversifier_decrypt(const std::vector& key, std::vector iv; iv.resize(AES_BLOCKSIZE); + // Decrypt using padded AES-256 (CBC) using a zero IV, ensuring that the decrypted data is the expected length AES256CBCDecrypt aes(key.data(), iv.data(), true); std::vector plaintext; plaintext.resize(AES_BLOCKSIZE); - aes.Decrypt(d.data(), d.size(), plaintext.data()); + int length = aes.Decrypt(d.data(), d.size(), plaintext.data()); + if (length != sizeof(uint64_t)) { + throw std::runtime_error("Invalid diversifier length"); + } - // Decrypt using padded AES-256 (CBC) using a zero IV + // Deserialize the diversifier CDataStream i_stream(SER_NETWORK, PROTOCOL_VERSION); i_stream.write((const char *)plaintext.data(), sizeof(uint64_t)); - // Deserialize the diversifier uint64_t i; i_stream >> i;