Skip to content

Commit

Permalink
minor modifications
Browse files Browse the repository at this point in the history
1. fix minor bugs
2. merge the latest branch
  • Loading branch information
zhangwfjh committed Mar 19, 2024
1 parent 20a7e45 commit 5a95b9b
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 14 deletions.
2 changes: 1 addition & 1 deletion psi/psi/core/kmprt17_mp_psi/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ psi_cc_library(
],
deps = [
"//psi/psi/core:communication",
"//psi/psi/utils",
"//psi/psi/utils:sync",
"//psi/psi/utils:test_utils",
"@com_google_absl//absl/types:span",
"@yacl//yacl/base:exception",
Expand Down
5 changes: 2 additions & 3 deletions psi/psi/core/kmprt17_mp_psi/kmprt17_hashing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,14 @@
#include <utility>

#include "yacl/base/exception.h"
#include "yacl/crypto/utils/rand.h"

namespace psi::psi {

void KmprtCuckooHashing::Insert(uint128_t elem) {
static std::random_device rd{};
static std::uniform_int_distribution<uint8_t> unif_hash_id{0, 5};
auto insert_into = [this, &elem](uint8_t c) {
for (uint8_t retry{}; retry != 128 && elem != NONE; ++retry) {
uint8_t rand_idx = unif_hash_id(rd) % num_hashes_[c];
uint8_t rand_idx = yacl::crypto::FastRandU64() % num_hashes_[c];
uint8_t idx = (rand_idx + 1) % num_hashes_[c];
size_t addr;
do {
Expand Down
6 changes: 3 additions & 3 deletions psi/psi/core/kmprt17_mp_psi/kmprt17_mp_psi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

#include "psi/psi/core/communication.h"
#include "psi/psi/core/kmprt17_mp_psi/kmprt17_opprf.h"
#include "psi/psi/utils/utils.h"
#include "psi/psi/utils/sync.h"

namespace psi::psi {

Expand Down Expand Up @@ -82,9 +82,9 @@ auto KmprtParty::ZeroSharing(size_t count) const -> std::vector<Share> {
auto [ctx, wsize, me, leader] = CollectContext();
std::vector<Share> shares(wsize, Share(count));
for (size_t k{}; k != count; ++k) {
uint128_t sum{};
uint64_t sum{};
for (size_t dst{1}; dst != wsize; ++dst) {
sum ^= shares[dst][k] = yacl::crypto::FastRandU128();
sum ^= shares[dst][k] = yacl::crypto::FastRandU64();
}
shares[0][k] = sum;
}
Expand Down
13 changes: 6 additions & 7 deletions psi/psi/core/kmprt17_mp_psi/kmprt17_opprf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,16 @@ std::vector<uint64_t> KmprtOpprfRecv(
evals.reserve(num_ot);
size_t ot_idx{}, b{};
std::array<uint64_t, BATCH_SIZE> batch_evals;
// Step 2. For each bin, invoke single-query OPPRF
// Step 2. For each bin, invokes single-query OPPRF
for (uint8_t c{}; c != 2; ++c) {
size_t ot_begin{c == uint8_t{0} ? 0 : bin_sizes[0]};
size_t ot_end{c == uint8_t{0} ? bin_sizes[0] : num_ot};
for (size_t addr{}; addr != hashing.num_bins_[c]; ++addr, ++ot_idx) {
auto elem = hashing.GetBin(c, addr);
elem == KmprtCuckooHashing::NONE && (elem = yc::FastRandU128());
receiver.Encode(ot_idx, elem,
absl::Span{reinterpret_cast<uint8_t*>(&batch_evals[b++]),
sizeof(uint64_t)});
receiver.Encode(
ot_idx, elem,
{reinterpret_cast<uint8_t*>(&batch_evals[b++]), sizeof(uint64_t)});
if (auto batch_size = (ot_idx - ot_begin) % BATCH_SIZE + 1;
batch_size == BATCH_SIZE || ot_idx + 1 == ot_end) {
b = 0;
Expand Down Expand Up @@ -144,7 +144,7 @@ void KmprtOpprfSend(const std::shared_ptr<yacl::link::Context>& ctx,
}
size_t ot_idx{};
auto evaluator = sender.GetOprf();
// Step 2. For each bin, invoke single-query OPPRF
// Step 2. For each bin, invokes single-query OPPRF
for (uint8_t c{}; c != 2; ++c) {
size_t ot_begin{c == uint8_t{0} ? 0 : bin_sizes[0]};
size_t ot_end{c == uint8_t{0} ? bin_sizes[0] : num_ot};
Expand All @@ -165,8 +165,7 @@ void KmprtOpprfSend(const std::shared_ptr<yacl::link::Context>& ctx,
for (auto it = bin.cbegin(); it != bin.cend(); ++it) {
uint64_t eval = evaluator->Eval(ot_idx, it->first);
auto index =
ro.Gen<size_t>(absl::MakeSpan(reinterpret_cast<uint8_t*>(&eval),
sizeof eval),
ro.Gen<size_t>({reinterpret_cast<uint8_t*>(&eval), sizeof eval},
nonce) %
table.size();
if (table[index] != uint64_t{0}) {
Expand Down

0 comments on commit 5a95b9b

Please sign in to comment.