Skip to content

Commit

Permalink
Merge pull request #14 from lightbulb128/batch_ops
Browse files Browse the repository at this point in the history
Batch ops
  • Loading branch information
lightbulb128 authored Oct 16, 2024
2 parents 14304e8 + fdf6faa commit 3354734
Show file tree
Hide file tree
Showing 29 changed files with 1,999 additions and 313 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ if (TROY_MEMORY_POOL_UNSAFE)
endif()
endif()

set(TROY_PYBIND11_GIT_HASH "8a099e44b3d5f85b20f05828d919d2332a8de841")
set(TROY_PYBIND11_GIT_HASH "a2e59f0e7065404b44dfe92a28aca47ba1378dc4")

if(TROY_PYBIND)
directory_nonexistent_or_empty("${CMAKE_CURRENT_SOURCE_DIR}/extern/pybind11" TROY_PYBIND11_NOT_FOUND)
Expand Down
2 changes: 1 addition & 1 deletion extern/pybind11
Submodule pybind11 updated 193 files
2 changes: 1 addition & 1 deletion pybind/develop.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# exit on error
set -e
set -ex

cd ../build
cmake .. -DTROY_PYBIND=ON -DCMAKE_BUILD_TYPE=Release
Expand Down
6 changes: 6 additions & 0 deletions pybind/src/ciphertext.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ void register_ciphertext(pybind11::module& m) {

py::class_<Ciphertext>(m, "Ciphertext")
.def(py::init<>())
.def("address", [](const Ciphertext& self){
return reinterpret_cast<uintptr_t>(&self);
})
.def("data_address", [](const Ciphertext& self){
return reinterpret_cast<uintptr_t>(self.data().raw_pointer());
})
.def("pool", &Ciphertext::pool)
.def("device_index", &Ciphertext::device_index)
.def("obtain_data", [](const Ciphertext& self){
Expand Down
55 changes: 55 additions & 0 deletions pybind/src/evaluator.cu
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,12 @@ void register_evaluator(pybind11::module& m) {
.def("multiply_plain_new", [](const Evaluator& self, const Ciphertext& encrypted, const Plaintext& plain, MemoryPoolHandleArgument pool) {
return self.multiply_plain_new(encrypted, plain, nullopt_default_pool(pool));
}, py::arg("encrypted"), py::arg("plain"), MEMORY_POOL_ARGUMENT)
.def("multiply_plain_new_batched", [](const Evaluator& self, const py::list& encrypted, const py::list& plain, MemoryPoolHandleArgument pool) {
return self.multiply_plain_new_batched(
cast_list<const Ciphertext*>(encrypted), cast_list<const Plaintext*>(plain),
nullopt_default_pool(pool)
);
}, py::arg("encrypted"), py::arg("plain"), MEMORY_POOL_ARGUMENT)

// transform_plain_to_ntt(plain, parms_id, destination, pool)
.def("transform_plain_to_ntt", [](const Evaluator& self, const Plaintext& plain, ParmsID parms_id, Plaintext& destination, MemoryPoolHandleArgument pool) {
Expand Down Expand Up @@ -282,6 +288,7 @@ void register_evaluator(pybind11::module& m) {
return self.complex_conjugate_new(encrypted, galois_keys, nullopt_default_pool(pool));
}, py::arg("encrypted"), py::arg("galois_keys"), MEMORY_POOL_ARGUMENT)

// packlwes related
.def("extract_lwe_new", [](const Evaluator& self, const Ciphertext& encrypted, size_t term, MemoryPoolHandleArgument pool) {
return self.extract_lwe_new(encrypted, term, nullopt_default_pool(pool));
}, py::arg("encrypted"), py::arg("term"), MEMORY_POOL_ARGUMENT)
Expand All @@ -301,6 +308,54 @@ void register_evaluator(pybind11::module& m) {
.def("pack_lwe_ciphertexts_new", [](const Evaluator& self, const std::vector<LWECiphertext>& lwe_ciphertexts, const GaloisKeys& automorphism_keys, MemoryPoolHandleArgument pool) {
return self.pack_lwe_ciphertexts_new(lwe_ciphertexts, automorphism_keys, nullopt_default_pool(pool));
}, py::arg("lwe_ciphertexts"), py::arg("automorphism_keys"), MEMORY_POOL_ARGUMENT)
.def("pack_lwe_ciphertexts_new", [](const Evaluator& self, const py::list& lwe_ciphertexts, const GaloisKeys& automorphism_keys, MemoryPoolHandleArgument pool) {
return self.pack_lwe_ciphertexts_new(cast_list<const LWECiphertext*>(lwe_ciphertexts), automorphism_keys, nullopt_default_pool(pool));
}, py::arg("lwe_ciphertexts"), py::arg("automorphism_keys"), MEMORY_POOL_ARGUMENT)
.def("pack_rlwe_ciphertexts_new", [](
const Evaluator& self, const py::list& rlwe_ciphertexts,
const GaloisKeys& automorphism_keys,
size_t shift, size_t input_interval, size_t output_interval,
MemoryPoolHandleArgument pool
) {
return self.pack_rlwe_ciphertexts_new(
cast_list<const Ciphertext*>(rlwe_ciphertexts),
automorphism_keys, shift, input_interval, output_interval,
nullopt_default_pool(pool)
);
},
py::arg("rlwe_ciphertexts"), py::arg("automorphism_keys"),
py::arg("shift"), py::arg("input_interval"), py::arg("output_interval"),
MEMORY_POOL_ARGUMENT
)

.def("pack_lwe_ciphertexts_new_batched", [](const Evaluator& self, const py::list& lwe_groups, const GaloisKeys& automorphism_keys, MemoryPoolHandleArgument pool) {
std::vector<std::vector<const LWECiphertext*>> cvv(lwe_groups.size());
for (size_t i = 0; i < lwe_groups.size(); i++) {
py::list pv = lwe_groups[i].cast<py::list>();
cvv[i] = cast_list<const LWECiphertext*>(pv);
}
return self.pack_lwe_ciphertexts_new_batched(cvv, automorphism_keys, nullopt_default_pool(pool));
}, py::arg("lwe_groups"), py::arg("automorphism_keys"), MEMORY_POOL_ARGUMENT)
.def("pack_rlwe_ciphertexts_new_batched", [](
const Evaluator& self, const py::list& rlwe_groups,
const GaloisKeys& automorphism_keys,
size_t shift, size_t input_interval, size_t output_interval,
MemoryPoolHandleArgument pool
) {
std::vector<std::vector<const Ciphertext*>> cvv(rlwe_groups.size());
for (size_t i = 0; i < rlwe_groups.size(); i++) {
py::list pv = rlwe_groups[i].cast<py::list>();
cvv[i] = cast_list<const Ciphertext*>(pv);
}
return self.pack_rlwe_ciphertexts_new_batched(
cvv, automorphism_keys, shift, input_interval, output_interval,
nullopt_default_pool(pool)
);
},
py::arg("rlwe_groups"), py::arg("automorphism_keys"),
py::arg("shift"), py::arg("input_interval"), py::arg("output_interval"),
MEMORY_POOL_ARGUMENT
)

// negacyclic_shift(encrypted, size_t shift, destination, pool)
.def("negacyclic_shift", [](const Evaluator& self, const Ciphertext& encrypted, size_t shift, Ciphertext& destination, MemoryPoolHandleArgument pool) {
Expand Down
9 changes: 9 additions & 0 deletions pybind/src/header.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,15 @@ size_t serialized_size_upperbound_he(const T& object, HeContextPointer context,
return object.serialized_size_upperbound(context, mode);
}

// Builds a std::vector<T> from a Python list by casting each element to T
// with pybind11's `cast<T>()`. Element order follows the list's iteration
// order; storage for all elements is reserved up front.
template <typename T>
std::vector<T> cast_list(const py::list& list) {
    std::vector<T> converted;
    converted.reserve(list.size());
    for (py::handle element : list) {
        // No try/catch here: a failed element cast surfaces to the caller.
        converted.push_back(element.cast<T>());
    }
    return converted;
}

typedef std::optional<MemoryPoolHandle> MemoryPoolHandleArgument;

inline MemoryPoolHandle nullopt_default_pool(MemoryPoolHandleArgument pool) {
Expand Down
3 changes: 3 additions & 0 deletions pybind/src/lwe_ciphertext.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ void register_lwe_ciphertext(pybind11::module& m) {

py::class_<LWECiphertext>(m, "LWECiphertext")
.def("pool", &LWECiphertext::pool)
.def("address", [](const LWECiphertext& self){
return reinterpret_cast<uintptr_t>(&self);
})
.def("device_index", &LWECiphertext::device_index)
.def("clone", [](const LWECiphertext& self, MemoryPoolHandleArgument pool){
return self.clone(nullopt_default_pool(pool));
Expand Down
6 changes: 6 additions & 0 deletions pybind/src/plaintext.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ void register_plaintext(pybind11::module& m) {

py::class_<Plaintext>(m, "Plaintext")
.def(py::init<>())
.def("address", [](const Plaintext& self){
return reinterpret_cast<uintptr_t>(&self);
})
.def("data_address", [](const Plaintext& self){
return reinterpret_cast<uintptr_t>(self.data().raw_pointer());
})
.def("obtain_data", [](const Plaintext& self){
troy::utils::DynamicArray<uint64_t> data = self.data().to_host();
return get_buffer_from_slice(data.const_reference());
Expand Down
6 changes: 3 additions & 3 deletions pybind/tests/test_basics.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import unittest
import pytroy
from pytroy import CoeffModulus, SchemeType
from pytroy import Modulus, PlainModulus, EncryptionParameters
from pytroy import BatchEncoder, CKKSEncoder, ParmsID, Plaintext, Ciphertext, HeContext, SecurityLevel
from pytroy import KeyGenerator, Encryptor, Decryptor, Evaluator, RelinKeys, GaloisKeys, MemoryPool
from pytroy import PlainModulus, EncryptionParameters
from pytroy import BatchEncoder, HeContext
from pytroy import KeyGenerator, Encryptor, MemoryPool

class Basics(unittest.TestCase):

Expand Down
18 changes: 15 additions & 3 deletions pybind/tests/test_he_operations.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytroy
from pytroy import Modulus, CoeffModulus, PlainModulus, EncryptionParameters, SchemeType
from pytroy import BatchEncoder, CKKSEncoder, ParmsID, Plaintext, Ciphertext, HeContext, SecurityLevel
from pytroy import KeyGenerator, Encryptor, Decryptor, Evaluator, RelinKeys, GaloisKeys
from pytroy import SchemeType
from pytroy import Ciphertext
from pytroy import KeyGenerator, Encryptor
import typing
import unittest
import numpy as np
Expand Down Expand Up @@ -203,6 +203,18 @@ def test_multiply_plain(self):
decoded = ghe.encoder.decode_simd(ghe.decryptor.decrypt_new(multiplied))
self.tester.assertTrue(ghe.near_equal(ghe.mul(message1, message2), decoded))

# test multiply plain batched
batch_size = 16
message1 = [ghe.random_simd_full() for _ in range(batch_size)]
message2 = [ghe.random_simd_full() for _ in range(batch_size)]
plain1 = [ghe.encoder.encode_simd(message1[i]) for i in range(batch_size)]
plain2 = [ghe.encoder.encode_simd(message2[i]) for i in range(batch_size)]
cipher1 = [ghe.encryptor.encrypt_symmetric_new(plain1[i], False) for i in range(batch_size)]
multiplied = ghe.evaluator.multiply_plain_new_batched(cipher1, plain2)
decoded = [ghe.encoder.decode_simd(ghe.decryptor.decrypt_new(multiplied[i])) for i in range(batch_size)]
for i in range(batch_size):
self.tester.assertTrue(ghe.near_equal(ghe.mul(message1[i], message2[i]), decoded[i]))

def test_rotate(self):
ghe = self.ghe
message = ghe.random_simd_full()
Expand Down
76 changes: 26 additions & 50 deletions src/app/matmul.cu
Original file line number Diff line number Diff line change
Expand Up @@ -539,76 +539,52 @@ namespace troy { namespace linear {
D_IMPL_ALL
#undef D_IMPL

Cipher2d MatmulHelper::pack_outputs(const Evaluator& evaluator, const GaloisKeys& autoKey, const Cipher2d& cipher) const {
Cipher2d MatmulHelper::pack_outputs(const Evaluator& evaluator, const GaloisKeys& auto_key, const Cipher2d& cipher) const {
if (!this->pack_lwe) {
throw std::invalid_argument("[MatmulHelper::packOutputs] PackLwe not enabled");
}
if (cipher.data().size() == 0 || cipher.data()[0].size() == 0) {
Cipher2d ret; ret.data().push_back(std::vector<Ciphertext>());
return ret;
}
size_t packSlots = this->input_block;
size_t totalCount = cipher.data().size() * cipher.data()[0].size();
std::vector<Ciphertext> output; output.reserve(ceil_div(totalCount, packSlots));
Ciphertext current; bool currentSet = false;
size_t currentSlot = 0;
size_t pack_slots = this->input_block;
size_t total_count = cipher.data().size() * cipher.data()[0].size();
std::vector<Ciphertext> output;

bool is_ntt = cipher.data()[0][0].is_ntt_form();

size_t field_trace_logn = 0;
size_t field_trace_n = 1;
while (field_trace_n != slot_count / packSlots) {
field_trace_logn += 1;
while (field_trace_n != slot_count / pack_slots) {
field_trace_n *= 2;
}

Ciphertext buffer = cipher.data()[0][0].clone(pool);
Ciphertext shifted = buffer.clone(pool);
size_t inherent_shift = pack_slots == 1 ? 0 : 2 * slot_count - (pack_slots - 1);

std::vector<std::vector<const Ciphertext*>> to_pack; to_pack.reserve(ceil_div(total_count, pack_slots));
to_pack.push_back(std::vector<const Ciphertext*>()); to_pack.back().reserve(pack_slots);
for (size_t i = 0; i < cipher.data().size(); i++) {
for (size_t j = 0; j < cipher.data()[0].size(); j++) {
size_t shift = packSlots - 1;
Ciphertext ciphertext = cipher.data()[i][j].clone(pool);
if (is_ntt) evaluator.transform_from_ntt_inplace(ciphertext);
if (shift != 0) {
evaluator.negacyclic_shift(ciphertext, 2 * slot_count - shift, buffer, pool);
} else {
buffer = ciphertext.clone(pool);
}

evaluator.divide_by_poly_modulus_degree_inplace(buffer, slot_count / packSlots);
if (is_ntt) evaluator.transform_to_ntt_inplace(buffer);

evaluator.field_trace_inplace(buffer, autoKey, field_trace_logn, pool);
if (is_ntt) evaluator.transform_from_ntt_inplace(buffer);

shift = currentSlot;
if (shift != 0) {
evaluator.negacyclic_shift(buffer, shift, shifted, pool);
} else {
shifted = buffer.clone(pool);
}

if (currentSet == false) {
current = shifted.clone(pool);
currentSet = true;
} else {
evaluator.add_inplace(current, shifted, pool);
}

currentSlot += 1;
if (currentSlot == packSlots) {
currentSlot = 0; currentSet = false;
output.push_back(std::move(current));
if (to_pack.size() == 0 || to_pack.back().size() == pack_slots) {
to_pack.push_back(std::vector<const Ciphertext*>()); to_pack.back().reserve(pack_slots);
}
to_pack.back().push_back(&cipher.data()[i][j]);
}
}
if (currentSet) {
output.push_back(std::move(current));
}
if (is_ntt) for (Ciphertext& c : output) {
evaluator.transform_to_ntt_inplace(c);

if (!batched_mul) {
output.reserve(ceil_div(total_count, pack_slots));
for (size_t i = 0; i < to_pack.size(); i++) {
output.push_back(evaluator.pack_rlwe_ciphertexts_new(
to_pack[i], auto_key, inherent_shift, input_block, 1, pool
));
}
} else {
output = evaluator.pack_rlwe_ciphertexts_new_batched(
to_pack, auto_key, inherent_shift, input_block, 1, pool
);
}
Cipher2d ret; ret.data().push_back(output);

Cipher2d ret; ret.data().push_back(std::move(output));
return ret;
}

Expand Down
4 changes: 1 addition & 3 deletions src/app/matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,7 @@ namespace troy { namespace linear {

// When this is enabled, in the `matmul` function
// we will invoke the batched operation. Note that this will
// consume much more memory than the non-batched version.
// Specifically, the memory is O(MNR/n), where n is the poly degree.
// If this is disabled, the memory requirement should be O(max(MN, NR)/n)
// consume more memory (presumably by an O(1) constant factor) than the non-batched version.
bool batched_mul = false;

inline void set_pool(MemoryPoolHandle pool) {
Expand Down
Loading

0 comments on commit 3354734

Please sign in to comment.