From 6e438e64eac19af0a1c66ceb809d4ae2d18d900f Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Fri, 12 Jul 2024 03:50:53 +1200 Subject: [PATCH] GH-43142: [C++][Parquet] Refactor Encryptor API to use arrow::util::span instead of raw pointers (#43195) ### Rationale for this change See #43142. This is a follow up to #43071 which refactored the Decryptor API and added extra checks to prevent segfaults. This PR makes similar changes to the Encryptor API for consistency and better maintainability. ### What changes are included in this PR? * Change `AesEncryptor::Encrypt` and `Encryptor::Encrypt` to use `arrow::util::span` instead of raw pointers * Replace the `AesEncryptor::CiphertextSizeDelta` method with a `CiphertextLength` method that checks for overflow and abstracts the size difference behaviour away from consumer code for improved readability. ### Are these changes tested? * This is mostly a refactoring of existing code so is covered by existing tests. ### Are there any user-facing changes? No * GitHub Issue: #43142 Lead-authored-by: Adam Reeve Co-authored-by: Antoine Pitrou Signed-off-by: Antoine Pitrou --- cpp/src/parquet/column_writer.cc | 15 +- .../parquet/encryption/encryption_internal.cc | 231 +++++++++++------- .../parquet/encryption/encryption_internal.h | 18 +- .../encryption/encryption_internal_nossl.cc | 18 +- .../encryption/encryption_internal_test.cc | 20 +- .../encryption/internal_file_encryptor.cc | 11 +- .../encryption/internal_file_encryptor.h | 6 +- .../encryption/key_toolkit_internal.cc | 15 +- cpp/src/parquet/metadata.cc | 34 +-- cpp/src/parquet/thrift_internal.h | 14 +- 10 files changed, 218 insertions(+), 164 deletions(-) diff --git a/cpp/src/parquet/column_writer.cc b/cpp/src/parquet/column_writer.cc index ac1c3ea2e3e20..c9f6e482981c0 100644 --- a/cpp/src/parquet/column_writer.cc +++ b/cpp/src/parquet/column_writer.cc @@ -303,9 +303,10 @@ class SerializedPageWriter : public PageWriter { if (data_encryptor_.get()) { UpdateEncryption(encryption::kDictionaryPage); PARQUET_THROW_NOT_OK(encryption_buffer_->Resize( - data_encryptor_->CiphertextSizeDelta() + output_data_len, false)); - output_data_len = data_encryptor_->Encrypt(compressed_data->data(), output_data_len, - encryption_buffer_->mutable_data()); + data_encryptor_->CiphertextLength(output_data_len), false)); + output_data_len = + data_encryptor_->Encrypt(compressed_data->span_as(), + encryption_buffer_->mutable_span_as()); output_data_buffer = encryption_buffer_->data(); } @@ -395,11 +396,11 @@ class SerializedPageWriter : public PageWriter { if (data_encryptor_.get()) { PARQUET_THROW_NOT_OK(encryption_buffer_->Resize( - data_encryptor_->CiphertextSizeDelta() + output_data_len, false)); + data_encryptor_->CiphertextLength(output_data_len), false)); UpdateEncryption(encryption::kDataPage); - output_data_len = data_encryptor_->Encrypt(compressed_data->data(), - static_cast(output_data_len), - encryption_buffer_->mutable_data()); + output_data_len = + data_encryptor_->Encrypt(compressed_data->span_as(), + encryption_buffer_->mutable_span_as()); output_data_buffer = encryption_buffer_->data(); } diff --git a/cpp/src/parquet/encryption/encryption_internal.cc b/cpp/src/parquet/encryption/encryption_internal.cc index c5d2d1728ba1e..6168dd2a9bd61 100644 --- a/cpp/src/parquet/encryption/encryption_internal.cc +++ b/cpp/src/parquet/encryption/encryption_internal.cc @@ -58,12 +58,12 @@ class AesEncryptor::AesEncryptorImpl { ~AesEncryptorImpl() { WipeOut(); } - int Encrypt(const uint8_t* plaintext, int plaintext_len, const uint8_t* key, - int key_len, const uint8_t* aad, int aad_len, uint8_t* ciphertext); + int Encrypt(span plaintext, span key, + span aad, span ciphertext); - int SignedFooterEncrypt(const uint8_t* footer, int footer_len, const uint8_t* key, - int key_len, const uint8_t* aad, int aad_len, - const uint8_t* nonce, uint8_t* encrypted_footer); + int SignedFooterEncrypt(span footer, span key, + span aad, span nonce, + span encrypted_footer); void WipeOut() { if (nullptr != ctx_) { EVP_CIPHER_CTX_free(ctx_); @@ -71,7 +71,21 @@ class AesEncryptor::AesEncryptorImpl { } } - int ciphertext_size_delta() { return ciphertext_size_delta_; } + [[nodiscard]] int32_t CiphertextLength(int64_t plaintext_len) const { + if (plaintext_len < 0) { + std::stringstream ss; + ss << "Negative plaintext length " << plaintext_len; + throw ParquetException(ss.str()); + } else if (plaintext_len > + std::numeric_limits::max() - ciphertext_size_delta_) { + std::stringstream ss; + ss << "Plaintext length " << plaintext_len << " plus ciphertext size delta " + << ciphertext_size_delta_ << " overflows int32"; + throw ParquetException(ss.str()); + } + + return static_cast(plaintext_len + ciphertext_size_delta_); + } private: EVP_CIPHER_CTX* ctx_; @@ -80,12 +94,12 @@ class AesEncryptor::AesEncryptorImpl { int ciphertext_size_delta_; int length_buffer_length_; - int GcmEncrypt(const uint8_t* plaintext, int plaintext_len, const uint8_t* key, - int key_len, const uint8_t* nonce, const uint8_t* aad, int aad_len, - uint8_t* ciphertext); + int GcmEncrypt(span plaintext, span key, + span nonce, span aad, + span ciphertext); - int CtrEncrypt(const uint8_t* plaintext, int plaintext_len, const uint8_t* key, - int key_len, const uint8_t* nonce, uint8_t* ciphertext); + int CtrEncrypt(span plaintext, span key, + span nonce, span ciphertext); }; AesEncryptor::AesEncryptorImpl::AesEncryptorImpl(ParquetCipher::type alg_id, int key_len, @@ -137,12 +151,21 @@ AesEncryptor::AesEncryptorImpl::AesEncryptorImpl(ParquetCipher::type alg_id, int } } -int AesEncryptor::AesEncryptorImpl::SignedFooterEncrypt( - const uint8_t* footer, int footer_len, const uint8_t* key, int key_len, - const uint8_t* aad, int aad_len, const uint8_t* nonce, uint8_t* encrypted_footer) { - if (key_length_ != key_len) { +int AesEncryptor::AesEncryptorImpl::SignedFooterEncrypt(span footer, + span key, + span aad, + span nonce, + span encrypted_footer) { + if (static_cast(key_length_) != key.size()) { + std::stringstream ss; + ss << "Wrong key length " << key.size() << ". Should be " << key_length_; + throw ParquetException(ss.str()); + } + + if (encrypted_footer.size() != footer.size() + ciphertext_size_delta_) { std::stringstream ss; - ss << "Wrong key length " << key_len << ". Should be " << key_length_; + ss << "Encrypted footer buffer length " << encrypted_footer.size() + << " does not match expected length " << (footer.size() + ciphertext_size_delta_); throw ParquetException(ss.str()); } @@ -150,72 +173,85 @@ int AesEncryptor::AesEncryptorImpl::SignedFooterEncrypt( throw ParquetException("Must use AES GCM (metadata) encryptor"); } - return GcmEncrypt(footer, footer_len, key, key_len, nonce, aad, aad_len, - encrypted_footer); + return GcmEncrypt(footer, key, nonce, aad, encrypted_footer); } -int AesEncryptor::AesEncryptorImpl::Encrypt(const uint8_t* plaintext, int plaintext_len, - const uint8_t* key, int key_len, - const uint8_t* aad, int aad_len, - uint8_t* ciphertext) { - if (key_length_ != key_len) { +int AesEncryptor::AesEncryptorImpl::Encrypt(span plaintext, + span key, + span aad, + span ciphertext) { + if (static_cast(key_length_) != key.size()) { std::stringstream ss; - ss << "Wrong key length " << key_len << ". Should be " << key_length_; + ss << "Wrong key length " << key.size() << ". Should be " << key_length_; throw ParquetException(ss.str()); } - uint8_t nonce[kNonceLength]; - memset(nonce, 0, kNonceLength); + if (ciphertext.size() != plaintext.size() + ciphertext_size_delta_) { + std::stringstream ss; + ss << "Ciphertext buffer length " << ciphertext.size() + << " does not match expected length " + << (plaintext.size() + ciphertext_size_delta_); + throw ParquetException(ss.str()); + } + + std::array nonce{}; // Random nonce - RAND_bytes(nonce, sizeof(nonce)); + RAND_bytes(nonce.data(), kNonceLength); if (kGcmMode == aes_mode_) { - return GcmEncrypt(plaintext, plaintext_len, key, key_len, nonce, aad, aad_len, - ciphertext); + return GcmEncrypt(plaintext, key, nonce, aad, ciphertext); } - return CtrEncrypt(plaintext, plaintext_len, key, key_len, nonce, ciphertext); + return CtrEncrypt(plaintext, key, nonce, ciphertext); } -int AesEncryptor::AesEncryptorImpl::GcmEncrypt(const uint8_t* plaintext, - int plaintext_len, const uint8_t* key, - int key_len, const uint8_t* nonce, - const uint8_t* aad, int aad_len, - uint8_t* ciphertext) { +int AesEncryptor::AesEncryptorImpl::GcmEncrypt(span plaintext, + span key, + span nonce, + span aad, + span ciphertext) { int len; int ciphertext_len; - uint8_t tag[kGcmTagLength]; - memset(tag, 0, kGcmTagLength); + std::array tag{}; + + if (nonce.size() != static_cast(kNonceLength)) { + std::stringstream ss; + ss << "Invalid nonce size " << nonce.size() << ", expected " << kNonceLength; + throw ParquetException(ss.str()); + } // Setting key and IV (nonce) - if (1 != EVP_EncryptInit_ex(ctx_, nullptr, nullptr, key, nonce)) { + if (1 != EVP_EncryptInit_ex(ctx_, nullptr, nullptr, key.data(), nonce.data())) { throw ParquetException("Couldn't set key and nonce"); } // Setting additional authenticated data - if ((nullptr != aad) && (1 != EVP_EncryptUpdate(ctx_, nullptr, &len, aad, aad_len))) { + if ((!aad.empty()) && (1 != EVP_EncryptUpdate(ctx_, nullptr, &len, aad.data(), + static_cast(aad.size())))) { throw ParquetException("Couldn't set AAD"); } // Encryption - if (1 != EVP_EncryptUpdate(ctx_, ciphertext + length_buffer_length_ + kNonceLength, - &len, plaintext, plaintext_len)) { + if (1 != + EVP_EncryptUpdate(ctx_, ciphertext.data() + length_buffer_length_ + kNonceLength, + &len, plaintext.data(), static_cast(plaintext.size()))) { throw ParquetException("Failed encryption update"); } ciphertext_len = len; // Finalization - if (1 != EVP_EncryptFinal_ex( - ctx_, ciphertext + length_buffer_length_ + kNonceLength + len, &len)) { + if (1 != + EVP_EncryptFinal_ex( + ctx_, ciphertext.data() + length_buffer_length_ + kNonceLength + len, &len)) { throw ParquetException("Failed encryption finalization"); } ciphertext_len += len; // Getting the tag - if (1 != EVP_CIPHER_CTX_ctrl(ctx_, EVP_CTRL_GCM_GET_TAG, kGcmTagLength, tag)) { + if (1 != EVP_CIPHER_CTX_ctrl(ctx_, EVP_CTRL_GCM_GET_TAG, kGcmTagLength, tag.data())) { throw ParquetException("Couldn't get AES-GCM tag"); } @@ -227,45 +263,53 @@ int AesEncryptor::AesEncryptorImpl::GcmEncrypt(const uint8_t* plaintext, ciphertext[1] = static_cast(0xff & (buffer_size >> 8)); ciphertext[0] = static_cast(0xff & (buffer_size)); } - std::copy(nonce, nonce + kNonceLength, ciphertext + length_buffer_length_); - std::copy(tag, tag + kGcmTagLength, - ciphertext + length_buffer_length_ + kNonceLength + ciphertext_len); + std::copy(nonce.begin(), nonce.begin() + kNonceLength, + ciphertext.begin() + length_buffer_length_); + std::copy(tag.begin(), tag.end(), + ciphertext.begin() + length_buffer_length_ + kNonceLength + ciphertext_len); return length_buffer_length_ + buffer_size; } -int AesEncryptor::AesEncryptorImpl::CtrEncrypt(const uint8_t* plaintext, - int plaintext_len, const uint8_t* key, - int key_len, const uint8_t* nonce, - uint8_t* ciphertext) { +int AesEncryptor::AesEncryptorImpl::CtrEncrypt(span plaintext, + span key, + span nonce, + span ciphertext) { int len; int ciphertext_len; + if (nonce.size() != static_cast(kNonceLength)) { + std::stringstream ss; + ss << "Invalid nonce size " << nonce.size() << ", expected " << kNonceLength; + throw ParquetException(ss.str()); + } + // Parquet CTR IVs are comprised of a 12-byte nonce and a 4-byte initial // counter field. // The first 31 bits of the initial counter field are set to 0, the last bit // is set to 1. - uint8_t iv[kCtrIvLength]; - memset(iv, 0, kCtrIvLength); - std::copy(nonce, nonce + kNonceLength, iv); + std::array iv{}; + std::copy(nonce.begin(), nonce.begin() + kNonceLength, iv.begin()); iv[kCtrIvLength - 1] = 1; // Setting key and IV - if (1 != EVP_EncryptInit_ex(ctx_, nullptr, nullptr, key, iv)) { + if (1 != EVP_EncryptInit_ex(ctx_, nullptr, nullptr, key.data(), iv.data())) { throw ParquetException("Couldn't set key and IV"); } // Encryption - if (1 != EVP_EncryptUpdate(ctx_, ciphertext + length_buffer_length_ + kNonceLength, - &len, plaintext, plaintext_len)) { + if (1 != + EVP_EncryptUpdate(ctx_, ciphertext.data() + length_buffer_length_ + kNonceLength, + &len, plaintext.data(), static_cast(plaintext.size()))) { throw ParquetException("Failed encryption update"); } ciphertext_len = len; // Finalization - if (1 != EVP_EncryptFinal_ex( - ctx_, ciphertext + length_buffer_length_ + kNonceLength + len, &len)) { + if (1 != + EVP_EncryptFinal_ex( + ctx_, ciphertext.data() + length_buffer_length_ + kNonceLength + len, &len)) { throw ParquetException("Failed encryption finalization"); } @@ -279,29 +323,29 @@ int AesEncryptor::AesEncryptorImpl::CtrEncrypt(const uint8_t* plaintext, ciphertext[1] = static_cast(0xff & (buffer_size >> 8)); ciphertext[0] = static_cast(0xff & (buffer_size)); } - std::copy(nonce, nonce + kNonceLength, ciphertext + length_buffer_length_); + std::copy(nonce.begin(), nonce.begin() + kNonceLength, + ciphertext.begin() + length_buffer_length_); return length_buffer_length_ + buffer_size; } AesEncryptor::~AesEncryptor() {} -int AesEncryptor::SignedFooterEncrypt(const uint8_t* footer, int footer_len, - const uint8_t* key, int key_len, const uint8_t* aad, - int aad_len, const uint8_t* nonce, - uint8_t* encrypted_footer) { - return impl_->SignedFooterEncrypt(footer, footer_len, key, key_len, aad, aad_len, nonce, - encrypted_footer); +int AesEncryptor::SignedFooterEncrypt(span footer, span key, + span aad, span nonce, + span encrypted_footer) { + return impl_->SignedFooterEncrypt(footer, key, aad, nonce, encrypted_footer); } void AesEncryptor::WipeOut() { impl_->WipeOut(); } -int AesEncryptor::CiphertextSizeDelta() { return impl_->ciphertext_size_delta(); } +int32_t AesEncryptor::CiphertextLength(int64_t plaintext_len) const { + return impl_->CiphertextLength(plaintext_len); +} -int AesEncryptor::Encrypt(const uint8_t* plaintext, int plaintext_len, const uint8_t* key, - int key_len, const uint8_t* aad, int aad_len, - uint8_t* ciphertext) { - return impl_->Encrypt(plaintext, plaintext_len, key, key_len, aad, aad_len, ciphertext); +int AesEncryptor::Encrypt(span plaintext, span key, + span aad, span ciphertext) { + return impl_->Encrypt(plaintext, key, aad, ciphertext); } AesEncryptor::AesEncryptor(ParquetCipher::type alg_id, int key_len, bool metadata, @@ -341,6 +385,11 @@ class AesDecryptor::AesDecryptorImpl { std::stringstream ss; ss << "Negative plaintext length " << plaintext_len; throw ParquetException(ss.str()); + } else if (plaintext_len > std::numeric_limits::max() - ciphertext_size_delta_) { + std::stringstream ss; + ss << "Plaintext length " << plaintext_len << " plus ciphertext size delta " + << ciphertext_size_delta_ << " overflows int32"; + throw ParquetException(ss.str()); } return plaintext_len + ciphertext_size_delta_; } @@ -481,13 +530,16 @@ int AesDecryptor::AesDecryptorImpl::GetCiphertextLength( } // Extract ciphertext length - int written_ciphertext_len = ((ciphertext[3] & 0xff) << 24) | - ((ciphertext[2] & 0xff) << 16) | - ((ciphertext[1] & 0xff) << 8) | ((ciphertext[0] & 0xff)); + uint32_t written_ciphertext_len = (static_cast(ciphertext[3]) << 24) | + (static_cast(ciphertext[2]) << 16) | + (static_cast(ciphertext[1]) << 8) | + (static_cast(ciphertext[0])); - if (written_ciphertext_len < 0) { + if (written_ciphertext_len > + static_cast(std::numeric_limits::max() - length_buffer_length_)) { std::stringstream ss; - ss << "Negative ciphertext length " << written_ciphertext_len; + ss << "Written ciphertext length " << written_ciphertext_len + << " plus length buffer length " << length_buffer_length_ << " overflows int"; throw ParquetException(ss.str()); } else if (ciphertext.size() < static_cast(written_ciphertext_len) + length_buffer_length_) { @@ -499,11 +551,11 @@ int AesDecryptor::AesDecryptorImpl::GetCiphertextLength( throw ParquetException(ss.str()); } - return written_ciphertext_len + length_buffer_length_; + return static_cast(written_ciphertext_len) + length_buffer_length_; } else { - if (ciphertext.size() > static_cast(std::numeric_limits::max())) { + if (ciphertext.size() > static_cast(std::numeric_limits::max())) { std::stringstream ss; - ss << "Ciphertext buffer length " << ciphertext.size() << " overflows int32"; + ss << "Ciphertext buffer length " << ciphertext.size() << " overflows int"; throw ParquetException(ss.str()); } return static_cast(ciphertext.size()); @@ -517,10 +569,8 @@ int AesDecryptor::AesDecryptorImpl::GcmDecrypt(span ciphertext, int len; int plaintext_len; - uint8_t tag[kGcmTagLength]; - memset(tag, 0, kGcmTagLength); - uint8_t nonce[kNonceLength]; - memset(nonce, 0, kNonceLength); + std::array tag{}; + std::array nonce{}; int ciphertext_len = GetCiphertextLength(ciphertext); @@ -540,12 +590,12 @@ int AesDecryptor::AesDecryptorImpl::GcmDecrypt(span ciphertext, // Extracting IV and tag std::copy(ciphertext.begin() + length_buffer_length_, - ciphertext.begin() + length_buffer_length_ + kNonceLength, nonce); + ciphertext.begin() + length_buffer_length_ + kNonceLength, nonce.begin()); std::copy(ciphertext.begin() + ciphertext_len - kGcmTagLength, - ciphertext.begin() + ciphertext_len, tag); + ciphertext.begin() + ciphertext_len, tag.begin()); // Setting key and IV - if (1 != EVP_DecryptInit_ex(ctx_, nullptr, nullptr, key.data(), nonce)) { + if (1 != EVP_DecryptInit_ex(ctx_, nullptr, nullptr, key.data(), nonce.data())) { throw ParquetException("Couldn't set key and IV"); } @@ -566,7 +616,7 @@ int AesDecryptor::AesDecryptorImpl::GcmDecrypt(span ciphertext, plaintext_len = len; // Checking the tag (authentication) - if (!EVP_CIPHER_CTX_ctrl(ctx_, EVP_CTRL_GCM_SET_TAG, kGcmTagLength, tag)) { + if (!EVP_CIPHER_CTX_ctrl(ctx_, EVP_CTRL_GCM_SET_TAG, kGcmTagLength, tag.data())) { throw ParquetException("Failed authentication"); } @@ -585,8 +635,7 @@ int AesDecryptor::AesDecryptorImpl::CtrDecrypt(span ciphertext, int len; int plaintext_len; - uint8_t iv[kCtrIvLength]; - memset(iv, 0, kCtrIvLength); + std::array iv{}; int ciphertext_len = GetCiphertextLength(ciphertext); @@ -606,7 +655,7 @@ int AesDecryptor::AesDecryptorImpl::CtrDecrypt(span ciphertext, // Extracting nonce std::copy(ciphertext.begin() + length_buffer_length_, - ciphertext.begin() + length_buffer_length_ + kNonceLength, iv); + ciphertext.begin() + length_buffer_length_ + kNonceLength, iv.begin()); // Parquet CTR IVs are comprised of a 12-byte nonce and a 4-byte initial // counter field. // The first 31 bits of the initial counter field are set to 0, the last bit @@ -614,7 +663,7 @@ int AesDecryptor::AesDecryptorImpl::CtrDecrypt(span ciphertext, iv[kCtrIvLength - 1] = 1; // Setting key and IV - if (1 != EVP_DecryptInit_ex(ctx_, nullptr, nullptr, key.data(), iv)) { + if (1 != EVP_DecryptInit_ex(ctx_, nullptr, nullptr, key.data(), iv.data())) { throw ParquetException("Couldn't set key and IV"); } diff --git a/cpp/src/parquet/encryption/encryption_internal.h b/cpp/src/parquet/encryption/encryption_internal.h index 2d5450553c16d..a9a17f1ab98e3 100644 --- a/cpp/src/parquet/encryption/encryption_internal.h +++ b/cpp/src/parquet/encryption/encryption_internal.h @@ -61,18 +61,22 @@ class PARQUET_EXPORT AesEncryptor { ~AesEncryptor(); - /// Size difference between plaintext and ciphertext, for this cipher. - int CiphertextSizeDelta(); + /// The size of the ciphertext, for this cipher and the specified plaintext length. + [[nodiscard]] int32_t CiphertextLength(int64_t plaintext_len) const; /// Encrypts plaintext with the key and aad. Key length is passed only for validation. /// If different from value in constructor, exception will be thrown. - int Encrypt(const uint8_t* plaintext, int plaintext_len, const uint8_t* key, - int key_len, const uint8_t* aad, int aad_len, uint8_t* ciphertext); + int Encrypt(::arrow::util::span plaintext, + ::arrow::util::span key, + ::arrow::util::span aad, + ::arrow::util::span ciphertext); /// Encrypts plaintext footer, in order to compute footer signature (tag). - int SignedFooterEncrypt(const uint8_t* footer, int footer_len, const uint8_t* key, - int key_len, const uint8_t* aad, int aad_len, - const uint8_t* nonce, uint8_t* encrypted_footer); + int SignedFooterEncrypt(::arrow::util::span footer, + ::arrow::util::span key, + ::arrow::util::span aad, + ::arrow::util::span nonce, + ::arrow::util::span encrypted_footer); void WipeOut(); diff --git a/cpp/src/parquet/encryption/encryption_internal_nossl.cc b/cpp/src/parquet/encryption/encryption_internal_nossl.cc index ed323c4aa6167..2f6cdc8200016 100644 --- a/cpp/src/parquet/encryption/encryption_internal_nossl.cc +++ b/cpp/src/parquet/encryption/encryption_internal_nossl.cc @@ -29,24 +29,26 @@ class AesEncryptor::AesEncryptorImpl {}; AesEncryptor::~AesEncryptor() {} -int AesEncryptor::SignedFooterEncrypt(const uint8_t* footer, int footer_len, - const uint8_t* key, int key_len, const uint8_t* aad, - int aad_len, const uint8_t* nonce, - uint8_t* encrypted_footer) { +int AesEncryptor::SignedFooterEncrypt(::arrow::util::span footer, + ::arrow::util::span key, + ::arrow::util::span aad, + ::arrow::util::span nonce, + ::arrow::util::span encrypted_footer) { ThrowOpenSSLRequiredException(); return -1; } void AesEncryptor::WipeOut() { ThrowOpenSSLRequiredException(); } -int AesEncryptor::CiphertextSizeDelta() { +int32_t AesEncryptor::CiphertextLength(int64_t plaintext_len) const { ThrowOpenSSLRequiredException(); return -1; } -int AesEncryptor::Encrypt(const uint8_t* plaintext, int plaintext_len, const uint8_t* key, - int key_len, const uint8_t* aad, int aad_len, - uint8_t* ciphertext) { +int AesEncryptor::Encrypt(::arrow::util::span plaintext, + ::arrow::util::span key, + ::arrow::util::span aad, + ::arrow::util::span ciphertext) { ThrowOpenSSLRequiredException(); return -1; } diff --git a/cpp/src/parquet/encryption/encryption_internal_test.cc b/cpp/src/parquet/encryption/encryption_internal_test.cc index cf7eeef4c6446..22e14663ea81f 100644 --- a/cpp/src/parquet/encryption/encryption_internal_test.cc +++ b/cpp/src/parquet/encryption/encryption_internal_test.cc @@ -37,14 +37,12 @@ class TestAesEncryption : public ::testing::Test { AesEncryptor encryptor(cipher_type, key_length_, metadata, write_length); - int expected_ciphertext_len = - static_cast(plain_text_.size()) + encryptor.CiphertextSizeDelta(); + int32_t expected_ciphertext_len = + encryptor.CiphertextLength(static_cast(plain_text_.size())); std::vector ciphertext(expected_ciphertext_len, '\0'); - int ciphertext_length = - encryptor.Encrypt(str2bytes(plain_text_), static_cast(plain_text_.size()), - str2bytes(key_), static_cast(key_.size()), str2bytes(aad_), - static_cast(aad_.size()), ciphertext.data()); + int ciphertext_length = encryptor.Encrypt(str2span(plain_text_), str2span(key_), + str2span(aad_), ciphertext); ASSERT_EQ(ciphertext_length, expected_ciphertext_len); @@ -87,14 +85,12 @@ class TestAesEncryption : public ::testing::Test { AesEncryptor encryptor(cipher_type, key_length_, metadata, write_length); - int expected_ciphertext_len = - static_cast(plain_text_.size()) + encryptor.CiphertextSizeDelta(); + int32_t expected_ciphertext_len = + encryptor.CiphertextLength(static_cast(plain_text_.size())); std::vector ciphertext(expected_ciphertext_len, '\0'); - int ciphertext_length = - encryptor.Encrypt(str2bytes(plain_text_), static_cast(plain_text_.size()), - str2bytes(key_), static_cast(key_.size()), str2bytes(aad_), - static_cast(aad_.size()), ciphertext.data()); + int ciphertext_length = encryptor.Encrypt(str2span(plain_text_), str2span(key_), + str2span(aad_), ciphertext); AesDecryptor decryptor(cipher_type, key_length_, metadata, write_length); diff --git a/cpp/src/parquet/encryption/internal_file_encryptor.cc b/cpp/src/parquet/encryption/internal_file_encryptor.cc index 15bf52b84dd1b..a423cc678cccb 100644 --- a/cpp/src/parquet/encryption/internal_file_encryptor.cc +++ b/cpp/src/parquet/encryption/internal_file_encryptor.cc @@ -31,12 +31,13 @@ Encryptor::Encryptor(encryption::AesEncryptor* aes_encryptor, const std::string& aad_(aad), pool_(pool) {} -int Encryptor::CiphertextSizeDelta() { return aes_encryptor_->CiphertextSizeDelta(); } +int32_t Encryptor::CiphertextLength(int64_t plaintext_len) const { + return aes_encryptor_->CiphertextLength(plaintext_len); +} -int Encryptor::Encrypt(const uint8_t* plaintext, int plaintext_len, uint8_t* ciphertext) { - return aes_encryptor_->Encrypt(plaintext, plaintext_len, str2bytes(key_), - static_cast(key_.size()), str2bytes(aad_), - static_cast(aad_.size()), ciphertext); +int Encryptor::Encrypt(::arrow::util::span plaintext, + ::arrow::util::span ciphertext) { + return aes_encryptor_->Encrypt(plaintext, str2span(key_), str2span(aad_), ciphertext); } // InternalFileEncryptor diff --git a/cpp/src/parquet/encryption/internal_file_encryptor.h b/cpp/src/parquet/encryption/internal_file_encryptor.h index 3cbe53500c2c5..41ffc6fd51943 100644 --- a/cpp/src/parquet/encryption/internal_file_encryptor.h +++ b/cpp/src/parquet/encryption/internal_file_encryptor.h @@ -43,8 +43,10 @@ class PARQUET_EXPORT Encryptor { void UpdateAad(const std::string& aad) { aad_ = aad; } ::arrow::MemoryPool* pool() { return pool_; } - int CiphertextSizeDelta(); - int Encrypt(const uint8_t* plaintext, int plaintext_len, uint8_t* ciphertext); + [[nodiscard]] int32_t CiphertextLength(int64_t plaintext_len) const; + + int Encrypt(::arrow::util::span plaintext, + ::arrow::util::span ciphertext); bool EncryptColumnMetaData( bool encrypted_footer, diff --git a/cpp/src/parquet/encryption/key_toolkit_internal.cc b/cpp/src/parquet/encryption/key_toolkit_internal.cc index a3c7c996b130a..5d7925aa0318f 100644 --- a/cpp/src/parquet/encryption/key_toolkit_internal.cc +++ b/cpp/src/parquet/encryption/key_toolkit_internal.cc @@ -32,15 +32,14 @@ std::string EncryptKeyLocally(const std::string& key_bytes, const std::string& m static_cast(master_key.size()), false, false /*write_length*/); - int encrypted_key_len = - static_cast(key_bytes.size()) + key_encryptor.CiphertextSizeDelta(); + int32_t encrypted_key_len = + key_encryptor.CiphertextLength(static_cast(key_bytes.size())); std::string encrypted_key(encrypted_key_len, '\0'); - encrypted_key_len = key_encryptor.Encrypt( - reinterpret_cast(key_bytes.data()), - static_cast(key_bytes.size()), - reinterpret_cast(master_key.data()), - static_cast(master_key.size()), reinterpret_cast(aad.data()), - static_cast(aad.size()), reinterpret_cast(&encrypted_key[0])); + ::arrow::util::span encrypted_key_span( + reinterpret_cast(&encrypted_key[0]), encrypted_key_len); + + encrypted_key_len = key_encryptor.Encrypt(str2span(key_bytes), str2span(master_key), + str2span(aad), encrypted_key_span); return ::arrow::util::base64_encode( ::std::string_view(encrypted_key.data(), encrypted_key_len)); diff --git a/cpp/src/parquet/metadata.cc b/cpp/src/parquet/metadata.cc index d7be50a6116bd..4ea3b05340d71 100644 --- a/cpp/src/parquet/metadata.cc +++ b/cpp/src/parquet/metadata.cc @@ -640,11 +640,13 @@ class FileMetaData::FileMetaDataImpl { uint32_t serialized_len = metadata_len_; ThriftSerializer serializer; serializer.SerializeToBuffer(metadata_.get(), &serialized_len, &serialized_data); + ::arrow::util::span serialized_data_span(serialized_data, + serialized_len); // encrypt with nonce - auto nonce = const_cast(reinterpret_cast(signature)); - auto tag = const_cast(reinterpret_cast(signature)) + - encryption::kNonceLength; + ::arrow::util::span nonce(reinterpret_cast(signature), + encryption::kNonceLength); + auto tag = reinterpret_cast(signature) + encryption::kNonceLength; std::string key = file_decryptor_->GetFooterKey(); std::string aad = encryption::CreateFooterAad(file_decryptor_->file_aad()); @@ -653,13 +655,11 @@ class FileMetaData::FileMetaDataImpl { file_decryptor_->algorithm(), static_cast(key.size()), true, false /*write_length*/, nullptr); - std::shared_ptr encrypted_buffer = std::static_pointer_cast( - AllocateBuffer(file_decryptor_->pool(), - aes_encryptor->CiphertextSizeDelta() + serialized_len)); + std::shared_ptr encrypted_buffer = AllocateBuffer( + file_decryptor_->pool(), aes_encryptor->CiphertextLength(serialized_len)); uint32_t encrypted_len = aes_encryptor->SignedFooterEncrypt( - serialized_data, serialized_len, str2bytes(key), static_cast(key.size()), - str2bytes(aad), static_cast(aad.size()), nonce, - encrypted_buffer->mutable_data()); + serialized_data_span, str2span(key), str2span(aad), nonce, + encrypted_buffer->mutable_span_as()); // Delete AES encryptor object. It was created only to verify the footer signature. aes_encryptor->WipeOut(); delete aes_encryptor; @@ -701,12 +701,12 @@ class FileMetaData::FileMetaDataImpl { uint8_t* serialized_data; uint32_t serialized_len; serializer.SerializeToBuffer(metadata_.get(), &serialized_len, &serialized_data); + ::arrow::util::span serialized_data_span(serialized_data, + serialized_len); // encrypt the footer key - std::vector encrypted_data(encryptor->CiphertextSizeDelta() + - serialized_len); - unsigned encrypted_len = - encryptor->Encrypt(serialized_data, serialized_len, encrypted_data.data()); + std::vector encrypted_data(encryptor->CiphertextLength(serialized_len)); + int encrypted_len = encryptor->Encrypt(serialized_data_span, encrypted_data); // write unencrypted footer PARQUET_THROW_NOT_OK(dst->Write(serialized_data, serialized_len)); @@ -1559,11 +1559,11 @@ class ColumnChunkMetaDataBuilder::ColumnChunkMetaDataBuilderImpl { serializer.SerializeToBuffer(&column_chunk_->meta_data, &serialized_len, &serialized_data); + ::arrow::util::span serialized_data_span(serialized_data, + serialized_len); - std::vector encrypted_data(encryptor->CiphertextSizeDelta() + - serialized_len); - unsigned encrypted_len = - encryptor->Encrypt(serialized_data, serialized_len, encrypted_data.data()); + std::vector encrypted_data(encryptor->CiphertextLength(serialized_len)); + int encrypted_len = encryptor->Encrypt(serialized_data_span, encrypted_data); const char* temp = const_cast(reinterpret_cast(encrypted_data.data())); diff --git a/cpp/src/parquet/thrift_internal.h b/cpp/src/parquet/thrift_internal.h index 4e4d7ed9837df..b21b0e07afba2 100644 --- a/cpp/src/parquet/thrift_internal.h +++ b/cpp/src/parquet/thrift_internal.h @@ -417,8 +417,8 @@ class ThriftDeserializer { throw ParquetException(ss.str()); } // decrypt - auto decrypted_buffer = std::static_pointer_cast(AllocateBuffer( - decryptor->pool(), decryptor->PlaintextLength(static_cast(clen)))); + auto decrypted_buffer = AllocateBuffer( + decryptor->pool(), decryptor->PlaintextLength(static_cast(clen))); ::arrow::util::span cipher_buf(buf, clen); uint32_t decrypted_buffer_len = decryptor->Decrypt(cipher_buf, decrypted_buffer->mutable_span_as()); @@ -525,13 +525,13 @@ class ThriftSerializer { } } - int64_t SerializeEncryptedObj(ArrowOutputStream* out, uint8_t* out_buffer, + int64_t SerializeEncryptedObj(ArrowOutputStream* out, const uint8_t* out_buffer, uint32_t out_length, Encryptor* encryptor) { - auto cipher_buffer = std::static_pointer_cast(AllocateBuffer( - encryptor->pool(), - static_cast(encryptor->CiphertextSizeDelta() + out_length))); + auto cipher_buffer = + AllocateBuffer(encryptor->pool(), encryptor->CiphertextLength(out_length)); + ::arrow::util::span out_span(out_buffer, out_length); int cipher_buffer_len = - encryptor->Encrypt(out_buffer, out_length, cipher_buffer->mutable_data()); + encryptor->Encrypt(out_span, cipher_buffer->mutable_span_as()); PARQUET_THROW_NOT_OK(out->Write(cipher_buffer->data(), cipher_buffer_len)); return static_cast(cipher_buffer_len);