Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 72 additions & 34 deletions tools/ckks_client_init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
// * Load global CKKS params + public key (--context <prefix>).
// * Generate random 128-bit AES key -> <client>.aes.
// * Generate Ed25519 keypair -> <client>.ed25519.{priv,pub}.
// * Encode AES key bits in CKKS slots; encrypt w/ global pk.
// * Serialize ciphertext (level+1 polys) -> <client>.aes.ckks.ct.
// * Encrypt each AES key bit individually (128 ciphertexts).
// * Serialize all 128 ciphertexts to single file -> <client>.aes.ckks.ct.
// * SHA-256 hash ciphertext; sign hash w/ Ed25519 priv -> .sig.
//
// Optional:
Expand Down Expand Up @@ -116,7 +116,7 @@ PublicKey loadPub(const string &path) {
return PublicKey(As, Bs);
}

// Ciphertext writer: (level+1) polys; NTT domain.
// Single ciphertext writer: (level+1) polys; NTT domain.
void saveCiphertext(Ciphertext &ct, const string &path, double scale) {
ofstream ofs(path, ios::binary);
if (!ofs) throw runtime_error("open ct for write failed");
Expand Down Expand Up @@ -145,6 +145,47 @@ void saveCiphertext(Ciphertext &ct, const string &path, double scale) {
if (!ofs) throw runtime_error("write ct failed");
}

// Multi-ciphertext writer: save 128 ciphertexts (one per bit) in a single file
void saveMultiCiphertext(vector<Ciphertext> &cts, const string &path, double scale) {
if (cts.size() != 128) throw runtime_error("expected 128 ciphertexts");

ofstream ofs(path, ios::binary);
if (!ofs) throw runtime_error("open multi-ct for write failed");

// Write header
ofs.write(reinterpret_cast<const char*>(&CT_MAGIC), sizeof(uint32_t));
ofs.write(reinterpret_cast<const char*>(&VERSION), sizeof(uint32_t));
ofs.write(reinterpret_cast<const char*>(&scale), sizeof(double));

// Write number of ciphertexts
uint64_t numCts = 128;
ofs.write(reinterpret_cast<const char*>(&numCts), sizeof(uint64_t));

// Write each ciphertext
for (auto &ct : cts) {
int level = ct.getLevel();
uint64_t lvl = static_cast<uint64_t>(level);
ofs.write(reinterpret_cast<const char*>(&lvl), sizeof(uint64_t));

auto &As = ct.getPolysA();
auto &Bs = ct.getPolysB();
size_t need = static_cast<size_t>(level) + 1;
if (As.size() < need || Bs.size() < need)
throw runtime_error("ciphertext vectors shorter than level+1");

for (size_t i = 0; i < need; ++i) {
const auto &v0 = As[i].getCoeffs();
const auto &v1 = Bs[i].getCoeffs();
uint64_t n0 = v0.size(), n1 = v1.size();
ofs.write(reinterpret_cast<const char*>(&n0), sizeof(uint64_t));
ofs.write(reinterpret_cast<const char*>(v0.data()), n0*sizeof(uint64_t));
ofs.write(reinterpret_cast<const char*>(&n1), sizeof(uint64_t));
ofs.write(reinterpret_cast<const char*>(v1.data()), n1*sizeof(uint64_t));
}
}
if (!ofs) throw runtime_error("write multi-ct failed");
}

} // namespace ser

//------------------------------------------------------------------------------
Expand Down Expand Up @@ -235,21 +276,6 @@ static vector<unsigned char> sha256_file(const string& path) {
return digest;
}

//------------------------------------------------------------------------------
// Bit packing: AES bytes -> 128 {0,1} slots (LSB-first per byte)
//------------------------------------------------------------------------------
static vector<complex<double>> aeskey_to_bitslots(const unsigned char* k16) {
vector<complex<double>> slots;
slots.reserve(128);
for (int byte = 0; byte < 16; ++byte) {
unsigned char b = k16[byte];
for (int bit = 0; bit < 8; ++bit) {
int v = (b >> bit) & 1;
slots.emplace_back(static_cast<double>(v), 0.0);
}
}
return slots;
}

//------------------------------------------------------------------------------
// Memory usage (optional)
Expand Down Expand Up @@ -352,35 +378,47 @@ int main(int argc, char** argv) {
printf("Generated Ed25519 keypair -> %s (priv), %s (pub)\n",
privPemPath.c_str(), pubPemPath.c_str());

// CKKS encode + encrypt
// CKKS encode + encrypt each bit individually
Encoder encoder(params);
Encryptor encryptor(params);

auto slots = aeskey_to_bitslots(aes.data());
vector<Ciphertext> bitCiphertexts;
bitCiphertexts.reserve(128);

auto t0 = high_resolution_clock::now();
Plaintext pt = encoder.encode(slots);

// Encrypt each bit individually
// TODO(mpang): there are probably ways in C to make sure we don't allocate more memory than we'd need (by writing the ciphertext to the file one at a time)
// The current way of doing things is probably not memory efficient
for (int byte = 0; byte < 16; ++byte) {
unsigned char b = aes[byte];
for (int bit = 0; bit < 8; ++bit) {
int bitVal = (b >> bit) & 1;
vector<complex<double>> singleBitSlot(1, complex<double>(static_cast<double>(bitVal), 0.0));

Plaintext pt = encoder.encode(singleBitSlot);
Ciphertext ct = encryptor.encrypt(pt, pub);
bitCiphertexts.push_back(ct);
}
}

auto t1 = high_resolution_clock::now();
print_mem("after encode", rssBase);
print_mem("after encrypt all bits", rssBase);

Ciphertext ct = encryptor.encrypt(pt, pub);
// Save all ciphertexts in one file
ser::saveMultiCiphertext(bitCiphertexts, ctPath, static_cast<double>(params.getScale()));
auto t2 = high_resolution_clock::now();
print_mem("after encrypt", rssBase);

ser::saveCiphertext(ct, ctPath, static_cast<double>(params.getScale()));
auto t3 = high_resolution_clock::now();
print_mem("after serialize", rssBase);

ifstream tmp(ctPath, ios::binary | ios::ate);
long long ctsz = tmp ? static_cast<long long>(tmp.tellg()) : -1;

auto enc_us = duration_cast<microseconds>(t1 - t0).count();
auto enc_ms = duration_cast<milliseconds>(t2 - t1).count();
auto ser_us = duration_cast<microseconds>(t3 - t2).count();
auto enc_ms = duration_cast<milliseconds>(t1 - t0).count();
auto ser_ms = duration_cast<milliseconds>(t2 - t1).count();

printf("CKKS encode: %lld us\n", (long long)enc_us);
printf("CKKS encrypt: %lld ms\n", (long long)enc_ms);
printf("CKKS serialize: %lld us\n", (long long)ser_us);
printf("Ciphertext size: %lld bytes -> %s\n", ctsz, ctPath.c_str());
printf("CKKS encrypt 128 bits: %lld ms\n", (long long)enc_ms);
printf("CKKS serialize: %lld ms\n", (long long)ser_ms);
printf("Total ciphertext size: %lld bytes -> %s\n", ctsz, ctPath.c_str());

// Sign ciphertext hash
auto hash = sha256_file(ctPath);
Expand Down
12 changes: 3 additions & 9 deletions tools/ckks_server_prep.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,25 +95,19 @@ void saveSec(SecretKey &sk, const string &path) {

// --------------------------- Parameter Set ----------------------------------
static vector<uint64_t> DEMO_Q = {
2199028891649ULL,
1099512938497ULL, 1099515691009ULL, 1099516870657ULL, 1099521458177ULL,
1099522375681ULL, 1099523555329ULL, 1099525128193ULL, 1099526176769ULL,
1099529060353ULL, 1099535220737ULL, 1099536138241ULL, 1099537580033ULL,
1099538104321ULL, 1099540725761ULL, 1099540856833ULL, 1099543085057ULL,
1099544002561ULL, 1099544395777ULL, 1099548327937ULL, 1099550556161ULL,
1099551080449ULL, 1099553308673ULL, 1099556192257ULL, 1099557765121ULL
2199028891649, 1099512938497, 1099557765121
};

static vector<uint64_t> DEMO_P = {
2199028891649ULL, 2199030071297ULL, 2199031382017ULL
2199031382017
};

int main(int argc, char** argv) {
const char* prefix = (argc > 1) ? argv[1] : "ckks_demo_context";
printf("CKKS Server Prep (%s)\n", prefix);

try {
uint64_t poly_degree = 65536;
uint64_t poly_degree = 2048;
uint64_t scale = 1ULL << 41;

Parameters params(scale, static_cast<int>(poly_degree), DEMO_Q, DEMO_P);
Expand Down