Skip to content

Commit

Permalink
Port seal patch from SPU (#164)
Browse files Browse the repository at this point in the history
* Merge seal

* Add assert

* Update patch

* Update patch

* update

* revert
  • Loading branch information
anakinxc authored Aug 1, 2024
1 parent ac118f2 commit ff4787a
Show file tree
Hide file tree
Showing 4 changed files with 274 additions and 37 deletions.
252 changes: 252 additions & 0 deletions bazel/patches/seal.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1a7a2bfd..a27539fa 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -223,7 +223,7 @@ if(SEAL_USE_INTEL_HEXL)
message(STATUS "Intel HEXL: download ...")
seal_fetch_thirdparty_content(ExternalIntelHEXL)
else()
- find_package(HEXL 1.2.4)
+ find_package(HEXL 1.2)
if (NOT TARGET HEXL::hexl)
message(FATAL_ERROR "Intel HEXL: not found")
endif()
diff --git a/native/src/seal/context.cpp b/native/src/seal/context.cpp
index 887a1312..932d9774 100644
--- a/native/src/seal/context.cpp
+++ b/native/src/seal/context.cpp
@@ -477,7 +477,8 @@ namespace seal
// more than one modulus in coeff_modulus. This is equivalent to expanding
// the chain by one step. Otherwise, we set first_parms_id_ to equal
// key_parms_id_.
- if (!context_data_map_.at(key_parms_id_)->qualifiers_.parameters_set() || parms.coeff_modulus().size() == 1)
+ if (!context_data_map_.at(key_parms_id_)->qualifiers_.parameters_set() || parms.coeff_modulus().size() == 1 ||
+ !parms.use_special_prime())
{
first_parms_id_ = key_parms_id_;
}
diff --git a/native/src/seal/encryptionparams.cpp b/native/src/seal/encryptionparams.cpp
index 31e07441..6f8e6b2a 100644
--- a/native/src/seal/encryptionparams.cpp
+++ b/native/src/seal/encryptionparams.cpp
@@ -23,8 +23,10 @@ namespace seal
uint64_t poly_modulus_degree64 = static_cast<uint64_t>(poly_modulus_degree_);
uint64_t coeff_modulus_size64 = static_cast<uint64_t>(coeff_modulus_.size());
uint8_t scheme = static_cast<uint8_t>(scheme_);
+ uint8_t use_special_prime_size8 = static_cast<uint8_t>(use_special_prime_);

stream.write(reinterpret_cast<const char *>(&scheme), sizeof(uint8_t));
+ stream.write(reinterpret_cast<const char *>(&use_special_prime_size8), sizeof(uint8_t));
stream.write(reinterpret_cast<const char *>(&poly_modulus_degree64), sizeof(uint64_t));
stream.write(reinterpret_cast<const char *>(&coeff_modulus_size64), sizeof(uint64_t));
for (const auto &mod : coeff_modulus_)
@@ -63,6 +65,9 @@ namespace seal
// This constructor will throw if scheme is invalid
EncryptionParameters parms(scheme);

+ uint8_t use_special_prime_size8;
+ stream.read(reinterpret_cast<char *>(&use_special_prime_size8), sizeof(uint8_t));
+
// Read the poly_modulus_degree
uint64_t poly_modulus_degree64 = 0;
stream.read(reinterpret_cast<char *>(&poly_modulus_degree64), sizeof(uint64_t));
@@ -98,6 +103,7 @@ namespace seal
// Supposedly everything worked so set the values of member variables
parms.set_poly_modulus_degree(safe_cast<size_t>(poly_modulus_degree64));
parms.set_coeff_modulus(coeff_modulus);
+ parms.set_use_special_prime(use_special_prime_size8);

// Only BFV and BGV uses plain_modulus; set_plain_modulus checks that for
// other schemes it is zero
@@ -128,6 +134,7 @@ namespace seal
size_t total_uint64_count = add_safe(
size_t(1), // scheme
size_t(1), // poly_modulus_degree
+ size_t(1), // use_special_prime
coeff_modulus_size, plain_modulus_.uint64_count());

auto param_data(allocate_uint(total_uint64_count, pool_));
@@ -139,6 +146,7 @@ namespace seal
// Write the poly_modulus_degree. Note that it will always be positive.
*param_data_ptr++ = static_cast<uint64_t>(poly_modulus_degree_);

+ *param_data_ptr++ = static_cast<uint64_t>(use_special_prime_);
for (const auto &mod : coeff_modulus_)
{
*param_data_ptr++ = mod.value();
diff --git a/native/src/seal/encryptionparams.h b/native/src/seal/encryptionparams.h
index 9e1fbe48..8530eeeb 100644
--- a/native/src/seal/encryptionparams.h
+++ b/native/src/seal/encryptionparams.h
@@ -266,6 +266,11 @@ namespace seal
random_generator_ = std::move(random_generator);
}

+ inline void set_use_special_prime(bool flag)
+ {
+ use_special_prime_ = flag;
+ }
+
/**
Returns the encryption scheme type.
*/
@@ -274,6 +279,11 @@ namespace seal
return scheme_;
}

+ bool use_special_prime() const noexcept
+ {
+ return use_special_prime_;
+ }
+
/**
Returns the degree of the polynomial modulus parameter.
*/
@@ -360,6 +370,7 @@ namespace seal
std::size_t members_size = Serialization::ComprSizeEstimate(
util::add_safe(
sizeof(scheme_),
+ sizeof(use_special_prime_),
sizeof(std::uint64_t), // poly_modulus_degree_
sizeof(std::uint64_t), // coeff_modulus_size
coeff_modulus_total_size,
@@ -501,6 +512,8 @@ namespace seal

Modulus plain_modulus_{};

+ bool use_special_prime_ = true;
+
parms_id_type parms_id_ = parms_id_zero;
};
} // namespace seal
diff --git a/native/src/seal/evaluator.cpp b/native/src/seal/evaluator.cpp
index dabd3bab..61a96ae9 100644
--- a/native/src/seal/evaluator.cpp
+++ b/native/src/seal/evaluator.cpp
@@ -2382,6 +2382,7 @@ namespace seal
size_t encrypted_size = encrypted.size();
// Use key_context_data where permutation tables exist since previous runs.
auto galois_tool = context_.key_context_data()->galois_tool();
+ bool is_ntt_form = encrypted.is_ntt_form();

// Size check
if (!product_fits_in(coeff_count, coeff_modulus_size))
@@ -2412,7 +2413,7 @@ namespace seal
// DO NOT CHANGE EXECUTION ORDER OF FOLLOWING SECTION
// BEGIN: Apply Galois for each ciphertext
// Execution order is sensitive, since apply_galois is not inplace!
- if (parms.scheme() == scheme_type::bfv)
+ if (not is_ntt_form)
{
// !!! DO NOT CHANGE EXECUTION ORDER!!!

@@ -2426,7 +2427,7 @@ namespace seal
// Next transform encrypted.data(1)
galois_tool->apply_galois(encrypted_iter[1], coeff_modulus_size, galois_elt, coeff_modulus, temp);
}
- else if (parms.scheme() == scheme_type::ckks || parms.scheme() == scheme_type::bgv)
+ else
{
// !!! DO NOT CHANGE EXECUTION ORDER!!!

@@ -2440,10 +2441,6 @@ namespace seal
// Next transform encrypted.data(1)
galois_tool->apply_galois_ntt(encrypted_iter[1], coeff_modulus_size, galois_elt, temp);
}
- else
- {
- throw logic_error("scheme not implemented");
- }

// Wipe encrypted.data(1)
set_zero_poly(coeff_count, coeff_modulus_size, encrypted.data(1));
@@ -2530,6 +2527,7 @@ namespace seal
auto &key_context_data = *context_.key_context_data();
auto &key_parms = key_context_data.parms();
auto scheme = parms.scheme();
+ bool is_ntt_form = encrypted.is_ntt_form();

// Verify parameters.
if (!is_metadata_valid_for(encrypted, context_) || !is_buffer_valid(encrypted))
@@ -2559,14 +2557,6 @@ namespace seal
{
throw invalid_argument("pool is uninitialized");
}
- if (scheme == scheme_type::bfv && encrypted.is_ntt_form())
- {
- throw invalid_argument("BFV encrypted cannot be in NTT form");
- }
- if (scheme == scheme_type::ckks && !encrypted.is_ntt_form())
- {
- throw invalid_argument("CKKS encrypted must be in NTT form");
- }
if (scheme == scheme_type::bgv && !encrypted.is_ntt_form())
{
throw invalid_argument("BGV encrypted must be in NTT form");
@@ -2605,7 +2595,7 @@ namespace seal
set_uint(target_iter, decomp_modulus_size * coeff_count, t_target);

// In CKKS or BGV, t_target is in NTT form; switch back to normal form
- if (scheme == scheme_type::ckks || scheme == scheme_type::bgv)
+ if (is_ntt_form)
{
inverse_ntt_negacyclic_harvey(t_target, decomp_modulus_size, key_ntt_tables);
}
@@ -2632,7 +2622,7 @@ namespace seal
ConstCoeffIter t_operand;

// RNS-NTT form exists in input
- if ((scheme == scheme_type::ckks || scheme == scheme_type::bgv) && (I == J))
+ if (is_ntt_form && (I == J))
{
t_operand = target_iter[J];
}
@@ -2789,7 +2779,7 @@ namespace seal
SEAL_ITERATE(t_ntt, coeff_count, [fix](auto &K) { K += fix; });

uint64_t qi_lazy = qi << 1; // some multiples of qi
- if (scheme == scheme_type::ckks)
+ if (is_ntt_form)
{
// This ntt_negacyclic_harvey_lazy results in [0, 4*qi).
ntt_negacyclic_harvey_lazy(t_ntt, get<2>(J));
@@ -2802,7 +2792,7 @@ namespace seal
qi_lazy = qi << 2;
#endif
}
- else if (scheme == scheme_type::bfv)
+ else
{
inverse_ntt_negacyclic_harvey_lazy(get<0, 1>(J), get<2>(J));
}
diff --git a/native/src/seal/evaluator.h b/native/src/seal/evaluator.h
index 9e3dd576..bb598ddf 100644
--- a/native/src/seal/evaluator.h
+++ b/native/src/seal/evaluator.h
@@ -1355,10 +1355,12 @@ namespace seal
apply_galois_inplace(encrypted, galois_tool->get_elt_from_step(0), galois_keys, std::move(pool));
}

+ public:
void switch_key_inplace(
Ciphertext &encrypted, util::ConstRNSIter target_iter, const KSwitchKeys &kswitch_keys,
std::size_t key_index, MemoryPoolHandle pool = MemoryManager::GetPool()) const;

+ private:
void multiply_plain_normal(Ciphertext &encrypted, const Plaintext &plain, MemoryPoolHandle pool) const;

void multiply_plain_ntt(Ciphertext &encrypted_ntt, const Plaintext &plain_ntt) const;
diff --git a/native/src/seal/serializable.h b/native/src/seal/serializable.h
index a940190c..e490b302 100644
--- a/native/src/seal/serializable.h
+++ b/native/src/seal/serializable.h
@@ -135,6 +135,9 @@ namespace seal
return obj_.save(out, size, compr_mode);
}

+ const T& obj() const { return obj_; }
+
+ T& obj() { return obj_; }
private:
Serializable(T &&obj) : obj_(std::move(obj))
{}
8 changes: 4 additions & 4 deletions bazel/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ def _com_github_microsoft_seal():
sha256 = "af9bf0f0daccda2a8b7f344f13a5692e0ee6a45fea88478b2b90c35648bf2672",
strip_prefix = "SEAL-4.1.1",
type = "tar.gz",
patch_args = ["-p1"],
patches = ["@psi//bazel:patches/seal.patch"],
urls = [
"https://github.com/microsoft/SEAL/archive/refs/tags/v4.1.1.tar.gz",
],
Expand Down Expand Up @@ -228,13 +230,11 @@ def _com_google_flatbuffers():
urls = [
"https://github.com/google/flatbuffers/archive/refs/tags/v24.3.25.tar.gz",
],
patches = [
"@psi//bazel:patches/flatbuffers.patch",
],
patch_args = ["-p1"],
patch_cmds = [
# hack to make sure this file is removed
"rm grpc/BUILD.bazel",
"rm grpc/src/compiler/BUILD.bazel",
"rm src/BUILD.bazel",
],
build_file = "@psi//bazel:flatbuffers.BUILD",
)
Expand Down
47 changes: 16 additions & 31 deletions bazel/seal.BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -21,45 +21,30 @@ filegroup(
srcs = glob(["**"]),
)

config_setting(
name = "can_use_hexl",
constraint_values = [
"@platforms//cpu:x86_64",
],
values = {"compilation_mode": "opt"},
)

default_config = {
"SEAL_USE_MSGSL": "ON",
"SEAL_BUILD_DEPS": "OFF",
"SEAL_USE_ZLIB": "ON",
"SEAL_USE_INTEL_HEXL": "OFF",
"SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT": "OFF", #NOTE(juhou) required by apsi
"SEAL_USE_ZSTD": "ON",
"CMAKE_INSTALL_LIBDIR": "lib",
}

x64_hexl_config = {
"SEAL_USE_MSGSL": "ON",
"SEAL_BUILD_DEPS": "OFF",
"SEAL_USE_ZLIB": "ON",
"SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT": "OFF", #NOTE(juhou) required by apsi
"CMAKE_INSTALL_LIBDIR": "lib",
"CpuFeatures_DIR": "$EXT_BUILD_DEPS/cpu_features/lib/cmake/CpuFeatures/",
"EXT_BUILD_DEPS": "$EXT_BUILD_DEPS",
"SEAL_USE_ZSTD": "ON",
"SEAL_USE_INTEL_HEXL": "ON",
}

psi_cmake_external(
name = "seal",
cache_entries = default_config,
cache_entries = {
"SEAL_USE_MSGSL": "ON",
"SEAL_BUILD_DEPS": "OFF",
"SEAL_USE_ZLIB": "ON",
"SEAL_USE_INTEL_HEXL": "OFF",
"SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT": "OFF", #NOTE(juhou) required by apsi
"SEAL_USE_ZSTD": "ON",
"CMAKE_INSTALL_LIBDIR": "lib",
# Uncomment to use hexl
# "CpuFeatures_DIR": "$EXT_BUILD_DEPS/cpu_features/lib/cmake/CpuFeatures/",
# "EXT_BUILD_DEPS": "$EXT_BUILD_DEPS",
# "SEAL_USE_INTEL_HEXL": "ON",

},
lib_source = "@com_github_microsoft_seal//:all",
out_include_dir = "include/SEAL-4.1",
out_static_libs = ["libseal-4.1.a"],
deps = [
"@com_github_facebook_zstd//:zstd",
"@com_github_microsoft_gsl//:Microsoft.GSL",
"@zlib",
# Uncomment to use hexl
# "@com_intel_hexl//:hexl"
],
)
4 changes: 2 additions & 2 deletions psi/apsi_wrapper/api/api_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,12 @@ TEST(ApiTest, Works) {

receiver.ProcessResult(query_response_str, receiver_output_file);

sender.SaveSenderDb(sdb_out_file);
ASSERT_TRUE(sender.SaveSenderDb(sdb_out_file));
}

{
psi::apsi_wrapper::api::Sender sender;
sender.LoadSenderDb(sdb_out_file);
ASSERT_TRUE(sender.LoadSenderDb(sdb_out_file));
std::string params_str = sender.GenerateParams();

psi::apsi_wrapper::api::Receiver receiver;
Expand Down

0 comments on commit ff4787a

Please sign in to comment.