From 960c380727e43ba900a219ab876cb017c8f8762b Mon Sep 17 00:00:00 2001 From: Aaron Feickert <66188213+AaronFeickert@users.noreply.github.com> Date: Tue, 7 Dec 2021 12:41:51 -0600 Subject: [PATCH] Squash --- src/Makefile.am | 41 +- src/Makefile.test.include | 13 +- src/libspark/aead.cpp | 89 +++ src/libspark/aead.h | 29 + src/libspark/bpplus.cpp | 473 ++++++++++++++++ src/libspark/bpplus.h | 33 ++ src/libspark/bpplus_proof.h | 44 ++ src/libspark/chaum.cpp | 163 ++++++ src/libspark/chaum.h | 45 ++ src/libspark/chaum_proof.h | 32 ++ src/libspark/coin.cpp | 154 +++++ src/libspark/coin.h | 111 ++++ src/libspark/grootle.cpp | 564 +++++++++++++++++++ src/libspark/grootle.h | 53 ++ src/libspark/grootle_proof.h | 45 ++ src/libspark/hash.cpp | 148 +++++ src/libspark/hash.h | 25 + src/libspark/kdf.cpp | 57 ++ src/libspark/kdf.h | 22 + src/libspark/keys.cpp | 118 ++++ src/libspark/keys.h | 72 +++ src/libspark/mint_transaction.cpp | 43 ++ src/libspark/mint_transaction.h | 30 + src/libspark/params.cpp | 137 +++++ src/libspark/params.h | 67 +++ src/libspark/schnorr.cpp | 40 ++ src/libspark/schnorr.h | 22 + src/libspark/schnorr_proof.h | 27 + src/libspark/spend_transaction.cpp | 317 +++++++++++ src/libspark/spend_transaction.h | 68 +++ src/libspark/test/aead_test.cpp | 128 +++++ src/libspark/test/bpplus_test.cpp | 194 +++++++ src/libspark/test/chaum_test.cpp | 180 ++++++ src/libspark/test/coin_test.cpp | 107 ++++ src/libspark/test/encrypt_test.cpp | 52 ++ src/libspark/test/grootle_test.cpp | 153 +++++ src/libspark/test/mint_transaction_test.cpp | 44 ++ src/libspark/test/schnorr_test.cpp | 85 +++ src/libspark/test/spend_transaction_test.cpp | 110 ++++ src/libspark/test/transcript_test.cpp | 174 ++++++ src/libspark/transcript.cpp | 177 ++++++ src/libspark/transcript.h | 32 ++ src/libspark/util.cpp | 232 ++++++++ src/libspark/util.h | 84 +++ 44 files changed, 4832 insertions(+), 2 deletions(-) create mode 100644 src/libspark/aead.cpp create mode 100644 src/libspark/aead.h create mode 100644 src/libspark/bpplus.cpp create mode 100644 src/libspark/bpplus.h create mode 100644 src/libspark/bpplus_proof.h create mode 100644 src/libspark/chaum.cpp create mode 100644 src/libspark/chaum.h create mode 100644 src/libspark/chaum_proof.h create mode 100644 src/libspark/coin.cpp create mode 100644 src/libspark/coin.h create mode 100644 src/libspark/grootle.cpp create mode 100644 src/libspark/grootle.h create mode 100644 src/libspark/grootle_proof.h create mode 100644 src/libspark/hash.cpp create mode 100644 src/libspark/hash.h create mode 100644 src/libspark/kdf.cpp create mode 100644 src/libspark/kdf.h create mode 100644 src/libspark/keys.cpp create mode 100644 src/libspark/keys.h create mode 100644 src/libspark/mint_transaction.cpp create mode 100644 src/libspark/mint_transaction.h create mode 100644 src/libspark/params.cpp create mode 100644 src/libspark/params.h create mode 100644 src/libspark/schnorr.cpp create mode 100644 src/libspark/schnorr.h create mode 100644 src/libspark/schnorr_proof.h create mode 100644 src/libspark/spend_transaction.cpp create mode 100644 src/libspark/spend_transaction.h create mode 100644 src/libspark/test/aead_test.cpp create mode 100644 src/libspark/test/bpplus_test.cpp create mode 100644 src/libspark/test/chaum_test.cpp create mode 100644 src/libspark/test/coin_test.cpp create mode 100644 src/libspark/test/encrypt_test.cpp create mode 100644 src/libspark/test/grootle_test.cpp create mode 100644 src/libspark/test/mint_transaction_test.cpp create mode 100644 src/libspark/test/schnorr_test.cpp create mode 100644 src/libspark/test/spend_transaction_test.cpp create mode 100644 src/libspark/test/transcript_test.cpp create mode 100644 src/libspark/transcript.cpp create mode 100644 src/libspark/transcript.h create mode 100644 src/libspark/util.cpp create mode 100644 src/libspark/util.h diff --git a/src/Makefile.am b/src/Makefile.am index fdeab1406c..1b1cbee591 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -55,6 +55,7 @@ LIBBITCOIN_CONSENSUS=libbitcoin_consensus.a LIBBITCOIN_CLI=libbitcoin_cli.a LIBBITCOIN_UTIL=libbitcoin_util.a LIBLELANTUS=liblelantus.a +LIBSPARK=libspark.a LIBBITCOIN_CRYPTO=crypto/libbitcoin_crypto.a LIBBITCOINQT=qt/libfiroqt.a LIBSECP256K1=secp256k1/libsecp256k1.la @@ -86,7 +87,8 @@ EXTRA_LIBRARIES += \ $(LIBBITCOIN_WALLET) \ $(LIBBITCOIN_ZMQ) \ $(LIBFIRO_SIGMA) \ - $(LIBLELANTUS) + $(LIBLELANTUS) \ + $(LIBSPARK) lib_LTLIBRARIES = $(LIBBITCOINCONSENSUS) @@ -625,6 +627,42 @@ libbitcoin_util_a_SOURCES = \ crypto/MerkleTreeProof/merkle-tree.cpp \ $(BITCOIN_CORE_H) +libspark_a_CPPFLAGS = $(AM_CPPFLAGS) $(BITCOIN_INCLUDES) +libspark_a_CXXFLAGS = $(AM_CXXFLAGS) $(PIE_FLAGS) +libspark_a_SOURCES = \ + libspark/transcript.h \ + libspark/transcript.cpp \ + libspark/params.h \ + libspark/params.cpp \ + libspark/schnorr_proof.h \ + libspark/schnorr.h \ + libspark/schnorr.cpp \ + libspark/chaum_proof.h \ + libspark/chaum.h \ + libspark/chaum.cpp \ + libspark/coin.h \ + libspark/coin.cpp \ + libspark/bpplus_proof.h \ + libspark/bpplus.h \ + libspark/bpplus.cpp \ + libspark/grootle_proof.h \ + libspark/grootle.h \ + libspark/grootle.cpp \ + libspark/keys.h \ + libspark/keys.cpp \ + libspark/util.h \ + libspark/util.cpp \ + libspark/aead.h \ + libspark/aead.cpp \ + libspark/kdf.h \ + libspark/kdf.cpp \ + libspark/hash.h \ + libspark/hash.cpp \ + libspark/mint_transaction.h \ + libspark/mint_transaction.cpp \ + libspark/spend_transaction.h \ + libspark/spend_transaction.cpp + liblelantus_a_CPPFLAGS = $(AM_CPPFLAGS) $(BITCOIN_INCLUDES) liblelantus_a_CXXFLAGS = $(AM_CXXFLAGS) $(PIE_FLAGS) liblelantus_a_SOURCES = \ @@ -729,6 +767,7 @@ firod_LDADD = \ $(LIBBITCOIN_WALLET) \ $(LIBFIRO_SIGMA) \ $(LIBLELANTUS) \ + $(LIBSPARK) \ $(LIBBITCOIN_ZMQ) \ $(LIBBITCOIN_CONSENSUS) \ $(LIBBITCOIN_CRYPTO) \ diff --git a/src/Makefile.test.include b/src/Makefile.test.include index 2e1e951287..27923c3a78 100644 --- a/src/Makefile.test.include +++ b/src/Makefile.test.include @@ -90,6 +90,16 @@ BITCOIN_TESTS = \ liblelantus/test/schnorr_test.cpp \ liblelantus/test/serialize_test.cpp \ liblelantus/test/sigma_extended_test.cpp \ + libspark/test/transcript_test.cpp \ + libspark/test/schnorr_test.cpp \ + libspark/test/chaum_test.cpp \ + libspark/test/bpplus_test.cpp \ + libspark/test/grootle_test.cpp \ + libspark/test/aead_test.cpp \ + libspark/test/encrypt_test.cpp \ + libspark/test/coin_test.cpp \ + libspark/test/mint_transaction_test.cpp \ + libspark/test/spend_transaction_test.cpp \ sigma/test/coin_spend_tests.cpp \ sigma/test/coin_tests.cpp \ sigma/test/primitives_tests.cpp \ @@ -199,7 +209,7 @@ test_test_bitcoin_LDADD = $(LIBBITCOIN_SERVER) -ltor test_test_bitcoin_SOURCES = $(BITCOIN_TESTS) $(JSON_TEST_FILES) $(RAW_TEST_FILES) test_test_bitcoin_CPPFLAGS = $(AM_CPPFLAGS) $(BITCOIN_INCLUDES) -I$(builddir)/test/ $(TESTDEFS) $(EVENT_CFLAGS) -test_test_bitcoin_LDADD += $(LIBBITCOIN_CLI) $(LIBBITCOIN_COMMON) $(LIBBITCOIN_UTIL) $(LIBBITCOIN_CONSENSUS) $(LIBBITCOIN_CRYPTO) $(LIBFIRO_SIGMA) $(LIBLELANTUS) $(LIBUNIVALUE) $(LIBLEVELDB) $(LIBLEVELDB_SSE42) $(LIBMEMENV) \ +test_test_bitcoin_LDADD += $(LIBBITCOIN_CLI) $(LIBBITCOIN_COMMON) $(LIBBITCOIN_UTIL) $(LIBBITCOIN_CONSENSUS) $(LIBBITCOIN_CRYPTO) $(LIBFIRO_SIGMA) $(LIBLELANTUS) $(LIBSPARK) $(LIBUNIVALUE) $(LIBLEVELDB) $(LIBLEVELDB_SSE42) $(LIBMEMENV) \ $(BACKTRACE_LIB) $(BOOST_LIBS) $(BOOST_UNIT_TEST_FRAMEWORK_LIB) $(LIBSECP256K1) $(EVENT_PTHREADS_LIBS) $(ZMQ_LIBS) $(ZLIB_LIBS) test_test_bitcoin_CXXFLAGS = $(AM_CXXFLAGS) $(PIE_FLAGS) if ENABLE_WALLET @@ -226,6 +236,7 @@ test_test_bitcoin_fuzzy_LDADD = \ $(LIBUNIVALUE) \ $(LIBBITCOIN_SERVER) \ $(LIBLELANTUS) \ + $(LIBSPARK) \ $(LIBBITCOIN_COMMON) \ $(LIBBITCOIN_UTIL) \ $(LIBBITCOIN_CONSENSUS) \ diff --git a/src/libspark/aead.cpp b/src/libspark/aead.cpp new file mode 100644 index 0000000000..d3686e255f --- /dev/null +++ b/src/libspark/aead.cpp @@ -0,0 +1,89 @@ +#include "aead.h" + +namespace spark { + +// Perform authenticated encryption with ChaCha20-Poly1305 +AEADEncryptedData AEAD::encrypt(const std::vector& key, const std::string additional_data, CDataStream& data) { + // Check key size + if (key.size() != AEAD_KEY_SIZE) { + throw std::invalid_argument("Bad AEAD key size"); + } + + // Set up the result structure + AEADEncryptedData result; + + // Internal size tracker; we know the size of the data already, and can ignore + int TEMP; + + // For our application, we can safely use a zero nonce since keys are never reused + std::vector iv; + iv.resize(AEAD_IV_SIZE); + + // Set up the cipher + EVP_CIPHER_CTX* ctx; + ctx = EVP_CIPHER_CTX_new(); + EVP_EncryptInit_ex(ctx, EVP_chacha20_poly1305(), NULL, key.data(), iv.data()); + + // Include the associated data + std::vector additional_data_bytes(additional_data.begin(), additional_data.end()); + EVP_EncryptUpdate(ctx, NULL, &TEMP, additional_data_bytes.data(), additional_data_bytes.size()); + + // Encrypt the plaintext + result.ciphertext.resize(data.size()); + EVP_EncryptUpdate(ctx, result.ciphertext.data(), &TEMP, reinterpret_cast(data.data()), data.size()); + EVP_EncryptFinal_ex(ctx, NULL, &TEMP); + + // Get the tag + result.tag.resize(AEAD_TAG_SIZE); + EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_AEAD_GET_TAG, AEAD_TAG_SIZE, result.tag.data()); + + // Clean up + EVP_CIPHER_CTX_free(ctx); + + return result; +} + +// Perform authenticated decryption with ChaCha20-Poly1305 +CDataStream AEAD::decrypt_and_verify(const std::vector& key, const std::string additional_data, AEADEncryptedData& data) { + // Check key size + if (key.size() != AEAD_KEY_SIZE) { + throw std::invalid_argument("Bad AEAD key size"); + } + + // Set up the result + CDataStream result(SER_NETWORK, PROTOCOL_VERSION); + + // Internal size tracker; we know the size of the data already, and can ignore + int TEMP; + + // For our application, we can safely use a zero nonce since keys are never reused + std::vector iv; + iv.resize(AEAD_IV_SIZE); + + // Set up the cipher + EVP_CIPHER_CTX* ctx; + ctx = EVP_CIPHER_CTX_new(); + EVP_DecryptInit_ex(ctx, EVP_chacha20_poly1305(), NULL, key.data(), iv.data()); + + // Include the associated data + std::vector additional_data_bytes(additional_data.begin(), additional_data.end()); + EVP_DecryptUpdate(ctx, NULL, &TEMP, additional_data_bytes.data(), additional_data_bytes.size()); + + // Decrypt the ciphertext + result.resize(data.ciphertext.size()); + EVP_DecryptUpdate(ctx, reinterpret_cast(result.data()), &TEMP, data.ciphertext.data(), data.ciphertext.size()); + + // Set the expected tag + EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_AEAD_SET_TAG, AEAD_TAG_SIZE, data.tag.data()); + + // Decrypt and clean up + int ret = EVP_DecryptFinal_ex(ctx, NULL, &TEMP); + EVP_CIPHER_CTX_free(ctx); + if (ret != 1) { + throw std::runtime_error("Bad AEAD authentication"); + } + + return result; +} + +} diff --git a/src/libspark/aead.h b/src/libspark/aead.h new file mode 100644 index 0000000000..e7af8ba926 --- /dev/null +++ b/src/libspark/aead.h @@ -0,0 +1,29 @@ +#ifndef FIRO_SPARK_AEAD_H +#define FIRO_SPARK_AEAD_H +#include +#include "util.h" + +namespace spark { + +struct AEADEncryptedData { + std::vector ciphertext; + std::vector tag; + + ADD_SERIALIZE_METHODS; + + template + inline void SerializationOp(Stream& s, Operation ser_action) { + READWRITE(ciphertext); + READWRITE(tag); + } +}; + +class AEAD { +public: + static AEADEncryptedData encrypt(const std::vector& key, const std::string additional_data, CDataStream& data); + static CDataStream decrypt_and_verify(const std::vector& key, const std::string associated_data, AEADEncryptedData& data); +}; + +} + +#endif diff --git a/src/libspark/bpplus.cpp b/src/libspark/bpplus.cpp new file mode 100644 index 0000000000..224b7a57ef --- /dev/null +++ b/src/libspark/bpplus.cpp @@ -0,0 +1,473 @@ +#include "bpplus.h" +#include "transcript.h" + +namespace spark { + +// Useful scalar constants +const Scalar ZERO = Scalar((uint64_t) 0); +const Scalar ONE = Scalar((uint64_t) 1); +const Scalar TWO = Scalar((uint64_t) 2); + +BPPlus::BPPlus( + const GroupElement& G_, + const GroupElement& H_, + const std::vector& Gi_, + const std::vector& Hi_, + const std::size_t N_) + : G (G_) + , H (H_) + , Gi (Gi_) + , Hi (Hi_) + , N (N_) +{ + if (Gi.size() != Hi.size()) { + throw std::invalid_argument("Bad BPPlus generator sizes!"); + } + + // Bit length must be a power of two + if ((N & (N - 1) != 0)) { + throw std::invalid_argument("Bad BPPlus bit length!"); + } + + // Compute 2**N-1 for optimized verification + TWO_N_MINUS_ONE = TWO; + for (int i = 0; i < log2(N); i++) { + TWO_N_MINUS_ONE *= TWO_N_MINUS_ONE; + } + TWO_N_MINUS_ONE -= ONE; +} + +static inline std::size_t log2(std::size_t n) { + std::size_t l = 0; + while ((n >>= 1) != 0) { + l++; + } + + return l; +} + +void BPPlus::prove( + const std::vector& v, + const std::vector& r, + const std::vector& C, + BPPlusProof& proof) { + // Check statement validity + std::size_t M = C.size(); + if (N*M > Gi.size()) { + throw std::invalid_argument("Bad BPPlus statement!"); + } + if (!(v.size() == M && r.size() == M)) { + throw std::invalid_argument("Bad BPPlus statement!"); + } + for (std::size_t j = 0; j < M; j++) { + if (!(G*v[j] + H*r[j] == C[j])) { + throw std::invalid_argument("Bad BPPlus statement!"); + } + } + + // Set up transcript + Transcript transcript("SPARK_BPPLUS"); + transcript.add("G", G); + transcript.add("H", H); + transcript.add("Gi", Gi); + transcript.add("Hi", Hi); + transcript.add("N", Scalar(N)); + transcript.add("C", C); + + // Decompose bits + std::vector> bits; + bits.resize(M); + for (std::size_t j = 0; j < M; j++) { + v[j].get_bits(bits[j]); + } + + // Compute aL, aR + std::vector aL, aR; + aL.reserve(N*M); + aR.reserve(N*M); + for (std::size_t j = 0; j < M; ++j) + { + for (std::size_t i = 1; i <= N; ++i) + { + aL.emplace_back(uint64_t(bits[j][bits[j].size() - i])); + aR.emplace_back(Scalar(uint64_t(bits[j][bits[j].size() - i])) - ONE); + } + } + + // Compute A + Scalar alpha; + alpha.randomize(); + + std::vector A_points; + std::vector A_scalars; + A_points.reserve(2*N*M + 1); + A_points.reserve(2*N*M + 1); + + A_points.emplace_back(H); + A_scalars.emplace_back(alpha); + for (std::size_t i = 0; i < N*M; i++) { + A_points.emplace_back(Gi[i]); + A_scalars.emplace_back(aL[i]); + A_points.emplace_back(Hi[i]); + A_scalars.emplace_back(aR[i]); + } + secp_primitives::MultiExponent A_multiexp(A_points, A_scalars); + proof.A = A_multiexp.get_multiple(); + transcript.add("A", proof.A); + + // Challenges + Scalar y = transcript.challenge("y"); + Scalar z = transcript.challenge("z"); + Scalar z_square = z.square(); + + // Challenge powers + std::vector y_powers; + y_powers.resize(M*N + 2); + y_powers[0] = ZERO; + y_powers[1] = y; + for (std::size_t i = 2; i < M*N + 2; i++) { + y_powers[i] = y_powers[i-1]*y; + } + + // Compute d + std::vector d; + d.resize(M*N); + d[0] = z_square; + for (std::size_t i = 1; i < N; i++) { + d[i] = TWO*d[i-1]; + } + for (std::size_t j = 1; j < M; j++) { + for (std::size_t i = 0; i < N; i++) { + d[j*N+i] = d[(j-1)*N+i]*z_square; + } + } + + // Compute aL1, aR1 + std::vector aL1, aR1; + for (std::size_t i = 0; i < N*M; i++) { + aL1.emplace_back(aL[i] - z); + aR1.emplace_back(aR[i] + d[i]*y_powers[N*M - i] + z); + } + + // Compute alpha1 + Scalar alpha1 = alpha; + Scalar z_even_powers = 1; + for (std::size_t j = 0; j < M; j++) { + z_even_powers *= z_square; + alpha1 += z_even_powers*r[j]*y_powers[N*M+1]; + } + + // Run the inner product rounds + std::vector Gi1(Gi); + std::vector Hi1(Hi); + std::vector a1(aL1); + std::vector b1(aR1); + std::size_t N1 = N*M; + + while (N1 > 1) { + N1 /= 2; + + Scalar dL, dR; + dL.randomize(); + dR.randomize(); + + // Compute cL, cR + Scalar cL, cR; + for (std::size_t i = 0; i < N1; i++) { + cL += a1[i]*y_powers[i+1]*b1[i+N1]; + cR += a1[i+N1]*y_powers[N1]*y_powers[i+1]*b1[i]; + } + + // Compute L, R + GroupElement L_, R_; + std::vector L_points, R_points; + std::vector L_scalars, R_scalars; + L_points.reserve(2*N1 + 2); + R_points.reserve(2*N1 + 2); + L_scalars.reserve(2*N1 + 2); + R_scalars.reserve(2*N1 + 2); + Scalar y_N1_inverse = y_powers[N1].inverse(); + for (std::size_t i = 0; i < N1; i++) { + L_points.emplace_back(Gi1[i+N1]); + L_scalars.emplace_back(a1[i]*y_N1_inverse); + L_points.emplace_back(Hi1[i]); + L_scalars.emplace_back(b1[i+N1]); + + R_points.emplace_back(Gi1[i]); + R_scalars.emplace_back(a1[i+N1]*y_powers[N1]); + R_points.emplace_back(Hi1[i+N1]); + R_scalars.emplace_back(b1[i]); + } + L_points.emplace_back(G); + L_scalars.emplace_back(cL); + L_points.emplace_back(H); + L_scalars.emplace_back(dL); + R_points.emplace_back(G); + R_scalars.emplace_back(cR); + R_points.emplace_back(H); + R_scalars.emplace_back(dR); + + secp_primitives::MultiExponent L_multiexp(L_points, L_scalars); + secp_primitives::MultiExponent R_multiexp(R_points, R_scalars); + L_ = L_multiexp.get_multiple(); + R_ = R_multiexp.get_multiple(); + proof.L.emplace_back(L_); + proof.R.emplace_back(R_); + + transcript.add("L", L_); + transcript.add("R", R_); + Scalar e = transcript.challenge("e"); + Scalar e_inverse = e.inverse(); + + // Compress round elements + for (std::size_t i = 0; i < N1; i++) { + Gi1[i] = Gi1[i]*e_inverse + Gi1[i+N1]*(e*y_N1_inverse); + Hi1[i] = Hi1[i]*e + Hi1[i+N1]*e_inverse; + a1[i] = a1[i]*e + a1[i+N1]*y_powers[N1]*e_inverse; + b1[i] = b1[i]*e_inverse + b1[i+N1]*e; + } + Gi1.resize(N1); + Hi1.resize(N1); + a1.resize(N1); + b1.resize(N1); + + // Update alpha1 + alpha1 = dL*e.square() + alpha1 + dR*e_inverse.square(); + } + + // Final proof elements + Scalar r_, s_, d_, eta_; + r_.randomize(); + s_.randomize(); + d_.randomize(); + eta_.randomize(); + + proof.A1 = Gi1[0]*r_ + Hi1[0]*s_ + G*(r_*y*b1[0] + s_*y*a1[0]) + H*d_; + proof.B = G*(r_*y*s_) + H*eta_; + + transcript.add("A1", proof.A1); + transcript.add("B", proof.B); + Scalar e1 = transcript.challenge("e1"); + + proof.r1 = r_ + a1[0]*e1; + proof.s1 = s_ + b1[0]*e1; + proof.d1 = eta_ + d_*e1 + alpha1*e1.square(); +} + +bool BPPlus::verify(const std::vector& C, const BPPlusProof& proof) { + std::vector> C_batch = {C}; + std::vector proof_batch = {proof}; + + return verify(C_batch, proof_batch); +} + +bool BPPlus::verify(const std::vector>& C, const std::vector& proofs) { + // Preprocess all proofs + if (!(C.size() == proofs.size())) { + return false; + } + std::size_t N_proofs = proofs.size(); + std::size_t max_M = 0; // maximum number of aggregated values across all proofs + + // Check aggregated input consistency + for (std::size_t k = 0; k < N_proofs; k++) { + std::size_t M = C[k].size(); + + // Require a power of two + if (M == 0) { + return false; + } + if ((M & (M - 1)) != 0) { + return false; + } + + // Track the maximum value + if (M > max_M) { + max_M = M; + } + + // Check inner produce round consistency + std::size_t rounds = proofs[k].L.size(); + if (proofs[k].R.size() != rounds) { + return false; + } + if (log2(N*M) != rounds) { + return false; + } + } + + // Check the bounds on the batch + if (max_M*N > Gi.size() || max_M*N > Hi.size()) { + return false; + } + + // Set up final multiscalar multiplication and common scalars + std::vector points; + std::vector scalars; + Scalar G_scalar, H_scalar; + + // Interleave the Gi and Hi scalars + for (std::size_t i = 0; i < max_M*N; i++) { + points.emplace_back(Gi[i]); + scalars.emplace_back(ZERO); + points.emplace_back(Hi[i]); + scalars.emplace_back(ZERO); + } + + // Process each proof and add to the batch + for (std::size_t k_proofs = 0; k_proofs < N_proofs; k_proofs++) { + const BPPlusProof proof = proofs[k_proofs]; + const std::size_t M = C[k_proofs].size(); + const std::size_t rounds = proof.L.size(); + + // Weight this proof in the batch + Scalar w = ZERO; + while (w == ZERO) { + w.randomize(); + } + + // Set up transcript + Transcript transcript("SPARK_BPPLUS"); + transcript.add("G", G); + transcript.add("H", H); + transcript.add("Gi", Gi); + transcript.add("Hi", Hi); + transcript.add("N", Scalar(N)); + transcript.add("C", C[k_proofs]); + transcript.add("A", proof.A); + + // Get challenges + Scalar y = transcript.challenge("y"); + Scalar y_inverse = y.inverse(); + Scalar y_NM = y; + for (std::size_t i = 0; i < rounds; i++) { + y_NM = y_NM.square(); + } + Scalar y_NM_1 = y_NM*y; + + Scalar z = transcript.challenge("z"); + Scalar z_square = z.square(); + + std::vector e; + std::vector e_inverse; + for (std::size_t j = 0; j < rounds; j++) { + transcript.add("L", proof.L[j]); + transcript.add("R", proof.R[j]); + e.emplace_back(transcript.challenge("e")); + e_inverse.emplace_back(e[j].inverse()); + } + + transcript.add("A1", proof.A1); + transcript.add("B", proof.B); + Scalar e1 = transcript.challenge("e1"); + Scalar e1_square = e1.square(); + + // C_j: -e1**2 * z**(2*(j + 1)) * y**(N*M + 1) * w + Scalar C_scalar = e1_square.negate()*z_square*y_NM_1*w; + for (std::size_t j = 0; j < M; j++) { + points.emplace_back(C[k_proofs][j]); + scalars.emplace_back(C_scalar); + + C_scalar *= z.square(); + } + + // B: -w + points.emplace_back(proof.B); + scalars.emplace_back(w.negate()); + + // A1: -w*e1 + points.emplace_back(proof.A1); + scalars.emplace_back(w.negate()*e1); + + // A: -w*e1**2 + points.emplace_back(proof.A); + scalars.emplace_back(w.negate()*e1_square); + + // H: w*d1 + H_scalar += w*proof.d1; + + // Compute d + std::vector d; + d.resize(N*M); + d[0] = z_square; + for (std::size_t i = 1; i < N; i++) { + d[i] = d[i-1] + d[i-1]; + } + for (std::size_t j = 1; j < M; j++) { + for (std::size_t i = 0; i < N; i++) { + d[j*N + i] = d[(j - 1)*N + i]*z_square; + } + } + + // Sum the elements of d + Scalar sum_d = z_square; + Scalar temp_z = sum_d; + std::size_t temp_2M = 2*M; + while (temp_2M > 2) { + sum_d += sum_d*temp_z; + temp_z = temp_z.square(); + temp_2M /= 2; + } + sum_d *= TWO_N_MINUS_ONE; + + // Sum the powers of y + Scalar sum_y; + Scalar track = y; + for (std::size_t i = 0; i < N*M; i++) { + sum_y += track; + track *= y; + } + + // G: w*(r1*y*s1 + e1**2*(y**(N*M + 1)*z*sum_d + (z**2-z)*sum_y)) + G_scalar += w*(proof.r1*y*proof.s1 + e1_square*(y_NM_1*z*sum_d + (z_square - z)*sum_y)); + + // Track some iterated exponential terms + Scalar iter_y_inv = ONE; // y.inverse()**i + Scalar iter_y_NM = y_NM; // y**(N*M - i) + + // Gi, Hi + for (std::size_t i = 0; i < N*M; i++) { + Scalar g = proof.r1*e1*iter_y_inv; + Scalar h = proof.s1*e1; + for (std::size_t j = 0; j < rounds; j++) { + if ((i >> j) & 1) { + g *= e[rounds-j-1]; + h *= e_inverse[rounds-j-1]; + } else { + h *= e[rounds-j-1]; + g *= e_inverse[rounds-j-1]; + } + } + + // Gi + scalars[2*i] += w*(g + e1_square*z); + + // Hi + scalars[2*i+1] += w*(h - e1_square*(d[i]*iter_y_NM+z)); + + // Update the iterated values + iter_y_inv *= y_inverse; + iter_y_NM *= y_inverse; + } + + // L, R + for (std::size_t j = 0; j < rounds; j++) { + points.emplace_back(proof.L[j]); + scalars.emplace_back(w*(e1_square.negate()*e[j].square())); + points.emplace_back(proof.R[j]); + scalars.emplace_back(w*(e1_square.negate()*e_inverse[j].square())); + } + } + + // Add the common generators + points.emplace_back(G); + scalars.emplace_back(G_scalar); + points.emplace_back(H); + scalars.emplace_back(H_scalar); + + // Test the batch + secp_primitives::MultiExponent multiexp(points, scalars); + return multiexp.get_multiple().isInfinity(); +} + +} \ No newline at end of file diff --git a/src/libspark/bpplus.h b/src/libspark/bpplus.h new file mode 100644 index 0000000000..9c6983ec7f --- /dev/null +++ b/src/libspark/bpplus.h @@ -0,0 +1,33 @@ +#ifndef FIRO_LIBSPARK_BPPLUS_H +#define FIRO_LIBSPARK_BPPLUS_H + +#include "bpplus_proof.h" +#include + +namespace spark { + +class BPPlus { +public: + BPPlus( + const GroupElement& G, + const GroupElement& H, + const std::vector& Gi, + const std::vector& Hi, + const std::size_t N); + + void prove(const std::vector& v, const std::vector& r, const std::vector& C, BPPlusProof& proof); + bool verify(const std::vector& C, const BPPlusProof& proof); // single proof + bool verify(const std::vector>& C, const std::vector& proofs); // batch of proofs + +private: + GroupElement G; + GroupElement H; + std::vector Gi; + std::vector Hi; + std::size_t N; + Scalar TWO_N_MINUS_ONE; +}; + +} + +#endif diff --git a/src/libspark/bpplus_proof.h b/src/libspark/bpplus_proof.h new file mode 100644 index 0000000000..214dbc9a0e --- /dev/null +++ b/src/libspark/bpplus_proof.h @@ -0,0 +1,44 @@ +#ifndef FIRO_LIBSPARK_BPPLUS_PROOF_H +#define FIRO_LIBSPARK_BPPLUS_PROOF_H + +#include "params.h" + +namespace spark { + +class BPPlusProof{ +public: + + static inline int int_log2(std::size_t number) { + assert(number != 0); + + int l2 = 0; + while ((number >>= 1) != 0) + l2++; + + return l2; + } + + inline std::size_t memoryRequired() const { + return 3*GroupElement::memoryRequired() + 3*Scalar::memoryRequired() + L.size()*GroupElement::memoryRequired() + R.size()*GroupElement::memoryRequired(); + } + + ADD_SERIALIZE_METHODS; + template + inline void SerializationOp(Stream& s, Operation ser_action) { + READWRITE(A); + READWRITE(A1); + READWRITE(B); + READWRITE(r1); + READWRITE(s1); + READWRITE(d1); + READWRITE(L); + READWRITE(R); + } + + GroupElement A, A1, B; + Scalar r1, s1, d1; + std::vector L, R; +}; +} + +#endif diff --git a/src/libspark/chaum.cpp b/src/libspark/chaum.cpp new file mode 100644 index 0000000000..e130ef2539 --- /dev/null +++ b/src/libspark/chaum.cpp @@ -0,0 +1,163 @@ +#include "chaum.h" +#include "transcript.h" + +namespace spark { + +Chaum::Chaum(const GroupElement& F_, const GroupElement& G_, const GroupElement& H_, const GroupElement& U_): + F(F_), G(G_), H(H_), U(U_) { +} + +Scalar Chaum::challenge( + const Scalar& mu, + const std::vector& S, + const std::vector& T, + const GroupElement& A1, + const std::vector& A2 +) { + Transcript transcript("SPARK_CHAUM"); + transcript.add("F", F); + transcript.add("G", G); + transcript.add("H", H); + transcript.add("U", U); + transcript.add("mu", mu); + transcript.add("S", S); + transcript.add("T", T); + transcript.add("A1", A1); + transcript.add("A2", A2); + + return transcript.challenge("c"); +} + +void Chaum::prove( + const Scalar& mu, + const std::vector& x, + const std::vector& y, + const std::vector& z, + const std::vector& S, + const std::vector& T, + ChaumProof& proof +) { + // Check statement validity + std::size_t n = x.size(); + if (!(y.size() == n && z.size() == n && S.size() == n && T.size() == n)) { + throw std::invalid_argument("Bad Chaum statement!"); + } + for (std::size_t i = 0; i < n; i++) { + if (!(F*x[i] + G*y[i] + H*z[i] == S[i] && T[i]*x[i] + G*y[i] == U)) { + throw std::invalid_argument("Bad Chaum statement!"); + } + } + + std::vector r; + r.resize(n); + std::vector s; + s.resize(n); + for (std::size_t i = 0; i < n; i++) { + r[i].randomize(); + s[i].randomize(); + } + Scalar t; + t.randomize(); + + proof.A1 = H*t; + proof.A2.resize(n); + for (std::size_t i = 0; i < n; i++) { + proof.A1 += F*r[i] + G*s[i]; + proof.A2[i] = T[i]*r[i] + G*s[i]; + } + + Scalar c = challenge(mu, S, T, proof.A1, proof.A2); + + proof.t1.resize(n); + proof.t3 = t; + Scalar c_power(c); + for (std::size_t i = 0; i < n; i++) { + proof.t1[i] = r[i] + c_power*x[i]; + proof.t2 += s[i] + c_power*y[i]; + proof.t3 += c_power*z[i]; + c_power *= c; + } +} + +bool Chaum::verify( + const Scalar& mu, + const std::vector& S, + const std::vector& T, + ChaumProof& proof +) { + // Check proof semantics + std::size_t n = S.size(); + if (!(T.size() == n && proof.A2.size() == n && proof.t1.size() == n)) { + throw std::invalid_argument("Bad Chaum semantics!"); + } + + Scalar c = challenge(mu, S, T, proof.A1, proof.A2); + std::vector c_powers; + c_powers.emplace_back(c); + for (std::size_t i = 1; i < n; i++) { + c_powers.emplace_back(c_powers[i-1]*c); + } + + // Weight the verification equations + Scalar w; + while (w.isZero()) { + w.randomize(); + } + + std::vector scalars; + std::vector points; + scalars.reserve(3*n + 5); + points.reserve(3*n + 5); + + // F + Scalar F_scalar; + for (std::size_t i = 0; i < n; i++) { + F_scalar -= proof.t1[i]; + } + scalars.emplace_back(F_scalar); + points.emplace_back(F); + + // G + scalars.emplace_back(proof.t2.negate() - w*proof.t2); + points.emplace_back(G); + + // H + scalars.emplace_back(proof.t3.negate()); + points.emplace_back(H); + + // U + Scalar U_scalar; + for (std::size_t i = 0; i < n; i++) { + U_scalar += c_powers[i]; + } + U_scalar *= w; + scalars.emplace_back(U_scalar); + points.emplace_back(U); + + // A1 + scalars.emplace_back(Scalar((uint64_t) 1)); + points.emplace_back(proof.A1); + + // {A2} + for (std::size_t i = 0; i < n; i++) { + scalars.emplace_back(w); + points.emplace_back(proof.A2[i]); + } + + // {S} + for (std::size_t i = 0; i < n; i++) { + scalars.emplace_back(c_powers[i]); + points.emplace_back(S[i]); + } + + // {T} + for (std::size_t i = 0; i < n; i++) { + scalars.emplace_back(w.negate()*proof.t1[i]); + points.emplace_back(T[i]); + } + + secp_primitives::MultiExponent multiexp(points, scalars); + return multiexp.get_multiple().isInfinity(); +} + +} diff --git a/src/libspark/chaum.h b/src/libspark/chaum.h new file mode 100644 index 0000000000..b15868b84c --- /dev/null +++ b/src/libspark/chaum.h @@ -0,0 +1,45 @@ +#ifndef FIRO_LIBSPARK_CHAUM_H +#define FIRO_LIBSPARK_CHAUM_H + +#include "chaum_proof.h" +#include + +namespace spark { + +class Chaum { +public: + Chaum(const GroupElement& F, const GroupElement& G, const GroupElement& H, const GroupElement& U); + + void prove( + const Scalar& mu, + const std::vector& x, + const std::vector& y, + const std::vector& z, + const std::vector& S, + const std::vector& T, + ChaumProof& proof + ); + bool verify( + const Scalar& mu, + const std::vector& S, + const std::vector& T, + ChaumProof& proof + ); + +private: + Scalar challenge( + const Scalar& mu, + const std::vector& S, + const std::vector& T, + const GroupElement& A1, + const std::vector& A2 + ); + const GroupElement& F; + const GroupElement& G; + const GroupElement& H; + const GroupElement& U; +}; + +} + +#endif diff --git a/src/libspark/chaum_proof.h b/src/libspark/chaum_proof.h new file mode 100644 index 0000000000..1885b883c1 --- /dev/null +++ b/src/libspark/chaum_proof.h @@ -0,0 +1,32 @@ +#ifndef FIRO_LIBSPARK_CHAUM_PROOF_H +#define FIRO_LIBSPARK_CHAUM_PROOF_H + +#include "params.h" + +namespace spark { + +class ChaumProof{ +public: + inline std::size_t memoryRequired() const { + return GroupElement::memoryRequired() + A2.size()*GroupElement::memoryRequired() + t1.size()*Scalar::memoryRequired() + 2*Scalar::memoryRequired(); + } + + ADD_SERIALIZE_METHODS; + template + inline void SerializationOp(Stream& s, Operation ser_action) { + READWRITE(A1); + READWRITE(A2); + READWRITE(t1); + READWRITE(t2); + READWRITE(t3); + } + +public: + GroupElement A1; + std::vector A2; + std::vector t1; + Scalar t2, t3; +}; +} + +#endif diff --git a/src/libspark/coin.cpp b/src/libspark/coin.cpp new file mode 100644 index 0000000000..66644808ee --- /dev/null +++ b/src/libspark/coin.cpp @@ -0,0 +1,154 @@ +#include "coin.h" + +namespace spark { + +using namespace secp_primitives; + +Coin::Coin() {} + +Coin::Coin( + const Params* params, + const char type, + const Scalar& k, + const Address& address, + const uint64_t v, + const std::string memo +) { + this->params = params; + + // Validate the type + if (type != COIN_TYPE_MINT && type != COIN_TYPE_SPEND) { + throw std::invalid_argument("Bad coin type"); + } + this->type = type; + + + // + // Common elements to both coin types + // + + // Construct the recovery key + this->K = SparkUtils::hash_div(address.get_d())*SparkUtils::hash_k(k); + + // Construct the serial commitment + this->S = this->params->get_F()*SparkUtils::hash_ser(k) + address.get_Q2(); + + // Construct the value commitment + this->C = this->params->get_G()*Scalar(v) + this->params->get_H()*SparkUtils::hash_val(k); + + // Check the memo validity, and pad if needed + if (memo.size() > this->params->get_memo_bytes()) { + throw std::invalid_argument("Memo is too large"); + } + std::vector memo_bytes(memo.begin(), memo.end()); + std::vector padded_memo(memo_bytes); + padded_memo.resize(this->params->get_memo_bytes()); + + // + // Type-specific elements + // + + if (this->type == COIN_TYPE_MINT) { + this->v = v; + + // Encrypt recipient data + MintCoinRecipientData r; + r.d = address.get_d(); + r.k = k; + r.memo = std::string(memo.begin(), memo.end()); + CDataStream r_stream(SER_NETWORK, PROTOCOL_VERSION); + r_stream << r; + this->r_ = AEAD::encrypt(SparkUtils::kdf_aead(address.get_Q1()*SparkUtils::hash_k(k)), "Mint coin data", r_stream); + } else { + // Encrypt recipient data + SpendCoinRecipientData r; + r.v = v; + r.d = address.get_d(); + r.k = k; + r.memo = std::string(memo.begin(), memo.end()); + CDataStream r_stream(SER_NETWORK, PROTOCOL_VERSION); + r_stream << r; + this->r_ = AEAD::encrypt(SparkUtils::kdf_aead(address.get_Q1()*SparkUtils::hash_k(k)), "Spend coin data", r_stream); + } +} + +// Validate a coin for identification +bool Coin::validate( + const IncomingViewKey& incoming_view_key, + IdentifiedCoinData& data +) { + // Check recovery key + if (SparkUtils::hash_div(data.d)*SparkUtils::hash_k(data.k) != this->K) { + return false; + } + + // Check value commitment + if (this->params->get_G()*Scalar(data.v) + this->params->get_H()*SparkUtils::hash_val(data.k) != this->C) { + return false; + } + + // Check serial commitment + data.i = incoming_view_key.get_diversifier(data.d); + + if (this->params->get_F()*(SparkUtils::hash_ser(data.k) + SparkUtils::hash_Q2(incoming_view_key.get_s1(), data.i)) + incoming_view_key.get_P2() != this->S) { + return false; + } + + return true; +} + +// Recover a coin +RecoveredCoinData Coin::recover(const FullViewKey& full_view_key, const IdentifiedCoinData& data) { + RecoveredCoinData recovered_data; + recovered_data.s = SparkUtils::hash_ser(data.k) + SparkUtils::hash_Q2(full_view_key.get_s1(), data.i) + full_view_key.get_s2(); + recovered_data.T = (this->params->get_U() + full_view_key.get_D().inverse())*recovered_data.s.inverse(); + + return recovered_data; +} + +// Identify a coin +IdentifiedCoinData Coin::identify(const IncomingViewKey& incoming_view_key) { + IdentifiedCoinData data; + + // Deserialization means this process depends on the coin type + if (this->type == COIN_TYPE_MINT) { + MintCoinRecipientData r; + + try { + // Decrypt recipient data + CDataStream stream = AEAD::decrypt_and_verify(SparkUtils::kdf_aead(this->K*incoming_view_key.get_s1()), "Mint coin data", this->r_); + stream >> r; + } catch (...) { + throw std::runtime_error("Unable to identify coin"); + } + + data.d = r.d; + data.v = this->v; + data.k = r.k; + data.memo = r.memo; + } else { + SpendCoinRecipientData r; + + try { + // Decrypt recipient data + CDataStream stream = AEAD::decrypt_and_verify(SparkUtils::kdf_aead(this->K*incoming_view_key.get_s1()), "Spend coin data", this->r_); + stream >> r; + } catch (...) { + throw std::runtime_error("Unable to identify coin"); + } + + data.d = r.d; + data.v = r.v; + data.k = r.k; + data.memo = r.memo; + } + + // Validate the coin + if (!validate(incoming_view_key, data)) { + throw std::runtime_error("Malformed coin"); + } + + return data; +} + +} diff --git a/src/libspark/coin.h b/src/libspark/coin.h new file mode 100644 index 0000000000..1d935b14a0 --- /dev/null +++ b/src/libspark/coin.h @@ -0,0 +1,111 @@ +#ifndef FIRO_SPARK_COIN_H +#define FIRO_SPARK_COIN_H +#include "bpplus.h" +#include "keys.h" +#include +#include "params.h" +#include "aead.h" +#include "util.h" + +namespace spark { + +using namespace secp_primitives; + +// Flags for coin types: those generated from mints, and those generated from spends +const char COIN_TYPE_MINT = 0; +const char COIN_TYPE_SPEND = 1; + +struct IdentifiedCoinData { + uint64_t i; // diversifier + std::vector d; // encrypted diversifier + uint64_t v; // value + Scalar k; // nonce + std::string memo; // memo +}; + +struct RecoveredCoinData { + Scalar s; // serial + GroupElement T; // tag +}; + +// Data to be encrypted for the recipient of a coin generated in a mint transaction +struct MintCoinRecipientData { + std::vector d; // encrypted diversifier + Scalar k; // nonce + std::string memo; // memo + + ADD_SERIALIZE_METHODS; + + template + inline void SerializationOp(Stream& s, Operation ser_action) { + READWRITE(d); + READWRITE(k); + READWRITE(memo); + } +}; + +// Data to be encrypted for the recipient of a coin generated in a spend transaction +struct SpendCoinRecipientData { + uint64_t v; // value + std::vector d; // encrypted diversifier + Scalar k; // nonce + std::string memo; // memo + + ADD_SERIALIZE_METHODS; + + template + inline void SerializationOp(Stream& s, Operation ser_action) { + READWRITE(v); + READWRITE(d); + READWRITE(k); + READWRITE(memo); + } +}; + +class Coin { +public: + Coin(); + Coin( + const Params* params, + const char type, + const Scalar& k, + const Address& address, + const uint64_t v, + const std::string memo + ); + + // Given an incoming view key, extract the coin's nonce, diversifier, value, and memo + IdentifiedCoinData identify(const IncomingViewKey& incoming_view_key); + + // Given a full view key, extract the coin's serial number and tag + RecoveredCoinData recover(const FullViewKey& full_view_key, const IdentifiedCoinData& data); + +protected: + bool validate(const IncomingViewKey& incoming_view_key, IdentifiedCoinData& data); + +public: + const Params* params; + char type; // type flag + GroupElement S, K, C; // serial commitment, recovery key, value commitment + AEADEncryptedData r_; // encrypted recipient data + uint64_t v; // value + + // Serialization depends on the coin type + ADD_SERIALIZE_METHODS; + template + inline void SerializationOp(Stream& s, Operation ser_action) { + READWRITE(type); + READWRITE(S); + READWRITE(K); + READWRITE(C); + READWRITE(r_); + + if (type == COIN_TYPE_MINT) { + READWRITE(v); + } + } +}; + +} + +#endif diff --git a/src/libspark/grootle.cpp b/src/libspark/grootle.cpp new file mode 100644 index 0000000000..9d1ebdb7d9 --- /dev/null +++ b/src/libspark/grootle.cpp @@ -0,0 +1,564 @@ +#include "grootle.h" +#include "transcript.h" + +namespace spark { + +// Useful scalar constants +const Scalar ZERO = Scalar(uint64_t(0)); +const Scalar ONE = Scalar(uint64_t(1)); +const Scalar TWO = Scalar(uint64_t(2)); + +Grootle::Grootle( + const GroupElement& H_, + const std::vector& Gi_, + const std::vector& Hi_, + const std::size_t n_, + const std::size_t m_) + : H (H_) + , Gi (Gi_) + , Hi (Hi_) + , n (n_) + , m (m_) +{ + if (!(n > 1 && m > 1)) { + throw std::invalid_argument("Bad Grootle size parameters!"); + } + if (Gi.size() != n*m || Hi.size() != n*m) { + throw std::invalid_argument("Bad Grootle generator size!"); + } +} + +// Compute a delta function vector +static inline std::vector convert_to_sigma(std::size_t num, const std::size_t n, const std::size_t m) { + std::vector result; + result.reserve(n*m); + + for (std::size_t j = 0; j < m; j++) { + for (std::size_t i = 0; i < n; i++) { + if (i == (num % n)) { + result.emplace_back(ONE); + } else { + result.emplace_back(ZERO); + } + } + num /= n; + } + + return result; +} + +// Decompose an integer with arbitrary base and padded size +static inline std::vector decompose(std::size_t num, const std::size_t n, const std::size_t m) { + std::vector result; + result.reserve(m); + + while (num != 0) { + result.emplace_back(num % n); + num /= n; + } + result.resize(m); + + return result; +} + +// Compute a double Pedersen vector commitment +static inline GroupElement vector_commit(const std::vector& Gi, const std::vector& Hi, const std::vector& a, const std::vector& b, const GroupElement& H, const Scalar& r) { + return secp_primitives::MultiExponent(Gi, a).get_multiple() + secp_primitives::MultiExponent(Hi, b).get_multiple() + H*r; +} + +// Compute a convolution with a degree-one polynomial +static inline void convolve(const Scalar& x_1, const Scalar& x_0, std::vector& coefficients) { + if (coefficients.empty()) { + throw std::runtime_error("Empty convolution coefficient vector!"); + } + + std::size_t degree = coefficients.size() - 1; + coefficients.emplace_back(x_1*coefficients[degree]); + for (std::size_t i = degree; i >=1; i--) { + coefficients[i] = x_0*coefficients[i] + x_1*coefficients[i-1]; + } + coefficients[0] *= x_0; +} + +static bool compute_fs( + const GrootleProof& proof, + const Scalar& x, + std::vector& f_, + const std::size_t n, + const std::size_t m) { + for (std::size_t j = 0; j < proof.f.size(); ++j) { + if(proof.f[j] == x) + return false; + } + + f_.reserve(n * m); + for (std::size_t j = 0; j < m; ++j) + { + f_.push_back(Scalar(uint64_t(0))); + Scalar temp; + std::size_t k = n - 1; + for (std::size_t i = 0; i < k; ++i) + { + temp += proof.f[j * k + i]; + f_.emplace_back(proof.f[j * k + i]); + } + f_[j * n] = x - temp; + } + return true; +} + +static void compute_batch_fis( + Scalar& f_sum, + const Scalar& f_i, + int j, + const std::vector& f, + const Scalar& y, + std::vector::iterator& ptr, + std::vector::iterator start_ptr, + std::vector::iterator end_ptr, + const std::size_t n) { + j--; + if (j == -1) + { + if(ptr >= start_ptr && ptr < end_ptr){ + *ptr++ += f_i * y; + f_sum += f_i; + } + return; + } + + Scalar t; + + for (std::size_t i = 0; i < n; i++) + { + t = f[j * n + i]; + t *= f_i; + + compute_batch_fis(f_sum, t, j, f, y, ptr, start_ptr, end_ptr, n); + } +} + +void Grootle::prove( + const std::size_t l, + const Scalar& s, + const std::vector& S, + const GroupElement& S1, + const Scalar& v, + const std::vector& V, + const GroupElement& V1, + GrootleProof& proof) { + // Check statement validity + std::size_t N = (std::size_t) pow(n, m); // padded input size + std::size_t size = S.size(); // actual input size + if (l >= size) { + throw std::invalid_argument("Bad Grootle secret index!"); + } + if (V.size() != S.size()) { + throw std::invalid_argument("Bad Grootle input vector sizes!"); + } + if (size > N || size == 0) { + throw std::invalid_argument("Bad Grootle size parameter!"); + } + if (S[l] + S1.inverse() != H*s) { + throw std::invalid_argument("Bad Grootle proof statement!"); + } + if (V[l] + V1.inverse() != H*v) { + throw std::invalid_argument("Bad Grootle proof statement!"); + } + + // Set up transcript + Transcript transcript("SPARK_GROOTLE"); + transcript.add("H", H); + transcript.add("Gi", Gi); + transcript.add("Hi", Hi); + transcript.add("n", Scalar(n)); + transcript.add("m", Scalar(m)); + transcript.add("S", S); + transcript.add("S1", S1); + transcript.add("V", V); + transcript.add("V1", V1); + + // Compute A + std::vector a; + a.resize(n*m); + for (std::size_t j = 0; j < m; j++) { + for (std::size_t i = 1; i < n; i++) { + a[j*n + i].randomize(); + a[j*n] -= a[j*n + i]; + } + } + std::vector d; + d.resize(n*m); + for (std::size_t i = 0; i < n*m; i++) { + d[i] = a[i].square().negate(); + } + Scalar rA; + rA.randomize(); + proof.A = vector_commit(Gi, Hi, a, d, H, rA); + + // Compute B + std::vector sigma = convert_to_sigma(l, n, m); + std::vector c; + c.resize(n*m); + for (std::size_t i = 0; i < n*m; i++) { + c[i] = a[i]*(ONE - TWO*sigma[i]); + } + Scalar rB; + rB.randomize(); + proof.B = vector_commit(Gi, Hi, sigma, c, H, rB); + + // Compute convolution terms + std::vector> P_i_j; + P_i_j.resize(size); + for (std::size_t i = 0; i < size - 1; ++i) + { + std::vector& coefficients = P_i_j[i]; + std::vector I = decompose(i, n, m); + coefficients.push_back(a[I[0]]); + coefficients.push_back(sigma[I[0]]); + for (std::size_t j = 1; j < m; ++j) { + convolve(sigma[j*n + I[j]], a[j*n + I[j]], coefficients); + } + } + + /* + * To optimize calculation of sum of all polynomials indices 's' = size-1 through 'n^m-1' we use the + * fact that sum of all of elements in each row of 'a' array is zero. Computation is done by going + * through n-ary representation of 's' and increasing "digit" at each position to 'n-1' one by one. + * During every step digits at higher positions are fixed and digits at lower positions go through all + * possible combinations with a total corresponding polynomial sum of 'x^j'. + * + * The math behind optimization (TeX notation): + * + * \sum_{i=s+1}^{N-1}p_i(x) = + * \sum_{j=0}^{m-1} + * \left[ + * \left( \sum_{i=s_j+1}^{n-1}(\delta_{l_j,i}x+a_{j,i}) \right) + * \left( \prod_{k=j}^{m-1}(\delta_{l_k,s_k}x+a_{k,s_k}) \right) + * x^j + * \right] + */ + + std::vector I = decompose(size - 1, n, m); + std::vector lj = decompose(l, n, m); + + std::vector p_i_sum; + p_i_sum.emplace_back(ONE); + std::vector> partial_p_s; + + // Pre-calculate product parts and calculate p_s(x) at the same time, put the latter into p_i_sum + for (std::ptrdiff_t j = m - 1; j >= 0; j--) { + partial_p_s.push_back(p_i_sum); + convolve(sigma[j*n + I[j]], a[j*n + I[j]], p_i_sum); + } + + for (std::size_t j = 0; j < m; j++) { + // \sum_{i=s_j+1}^{n-1}(\delta_{l_j,i}x+a_{j,i}) + Scalar a_sum(uint64_t(0)); + for (std::size_t i = I[j] + 1; i < n; i++) + a_sum += a[j * n + i]; + Scalar x_sum(uint64_t(lj[j] >= I[j]+1 ? 1 : 0)); + + // Multiply by \prod_{k=j}^{m-1}(\delta_{l_k,s_k}x+a_{k,s_k}) + std::vector &polynomial = partial_p_s[m - j - 1]; + convolve(x_sum, a_sum, polynomial); + + // Multiply by x^j and add to the result + for (std::size_t k = 0; k < m - j; k++) + p_i_sum[j + k] += polynomial[k]; + } + + P_i_j[size - 1] = p_i_sum; + + // Perform the commitment offsets + std::vector S_offset(S); + std::vector V_offset(V); + GroupElement S1_inverse = S1.inverse(); + GroupElement V1_inverse = V1.inverse(); + for (std::size_t k = 0; k < S_offset.size(); k++) { + S_offset[k] += S1_inverse; + V_offset[k] += V1_inverse; + } + + // Generate masks + std::vector rho_S, rho_V; + rho_S.resize(m); + rho_V.resize(m); + for (std::size_t j = 0; j < m; j++) { + rho_S[j].randomize(); + rho_V[j].randomize(); + } + + proof.X.reserve(m); + proof.X1.reserve(m); + for (std::size_t j = 0; j < m; ++j) + { + std::vector P_i; + P_i.reserve(size); + for (std::size_t i = 0; i < size; ++i){ + P_i.emplace_back(P_i_j[i][j]); + } + + // S + secp_primitives::MultiExponent mult_S(S_offset, P_i); + proof.X.emplace_back(mult_S.get_multiple() + H*rho_S[j]); + + // V + secp_primitives::MultiExponent mult_V(V_offset, P_i); + proof.X1.emplace_back(mult_V.get_multiple() + H*rho_V[j]); + } + + // Challenge + transcript.add("A", proof.A); + transcript.add("B", proof.B); + transcript.add("X", proof.X); + transcript.add("X1", proof.X1); + Scalar x = transcript.challenge("x"); + + // Compute f + proof.f.reserve(m*(n - 1)); + for (std::size_t j = 0; j < m; j++) + { + for (std::size_t i = 1; i < n; i++) { + proof.f.emplace_back(sigma[(j * n) + i] * x + a[(j * n) + i]); + } + } + + // Compute zA, zC + proof.z = rB * x + rA; + + // Compute zS, zV + proof.zS = s * x.exponent(uint64_t(m)); + proof.zV = v * x.exponent(uint64_t(m)); + Scalar sumS, sumV; + + Scalar x_powers(uint64_t(1)); + for (std::size_t j = 0; j < m; ++j) { + sumS += (rho_S[j] * x_powers); + sumV += (rho_V[j] * x_powers); + x_powers *= x; + } + proof.zS -= sumS; + proof.zV -= sumV; +} + +// Verify a single proof +bool Grootle::verify( + const std::vector& S, + const GroupElement& S1, + const std::vector& V, + const GroupElement& V1, + const std::size_t size, + const GrootleProof& proof) { + std::vector S1_batch = {S1}; + std::vector V1_batch = {V1}; + std::vector size_batch = {size}; + std::vector proof_batch = {proof}; + + return verify(S, S1_batch, V, V1_batch, size_batch, proof_batch); +} + +// Verify a batch of proofs +bool Grootle::verify( + const std::vector& S, + const std::vector& S1, + const std::vector& V, + const std::vector& V1, + const std::vector& sizes, + const std::vector& proofs) { + // Sanity checks + if (n < 2 || m < 2) { + LogPrintf("Verifier parameters are invalid"); + return false; + } + std::size_t M = proofs.size(); + std::size_t N = (std::size_t)pow(n, m); + + if (S.size() == 0) { + LogPrintf("Cannot have empty commitment set"); + return false; + } + if (S.size() > N) { + LogPrintf("Commitment set is too large"); + return false; + } + if (S.size() != V.size()) { + LogPrintf("Commitment set sizes do not match"); + return false; + } + if (S1.size() != M || V1.size() != M) { + LogPrintf("Invalid number of offsets provided"); + return false; + } + if (sizes.size() != M) { + LogPrintf("Invalid set size vector size"); + return false; + } + + // Check proof semantics + for (std::size_t t = 0; t < M; t++) { + GrootleProof proof = proofs[t]; + if (proof.X.size() != m || proof.X1.size() != m) { + LogPrintf("Bad proof vector size!"); + return false; + } + if (proof.f.size() != m*(n-1)) { + LogPrintf("Bad proof vector size!"); + return false; + } + } + + // Commitment binding weight; intentionally restricted range for efficiency, but must be nonzero + // NOTE: this may initialize with a PRNG, which should be sufficient for this use + std::random_device generator; + std::uniform_int_distribution distribution; + Scalar bind_weight(ZERO); + while (bind_weight == ZERO) { + bind_weight = Scalar(distribution(generator)); + } + + // Bind the commitment lists + std::vector commits; + commits.reserve(S.size()); + for (std::size_t i = 0; i < S.size(); i++) { + commits.emplace_back(S[i] + V[i]*bind_weight); + } + + // Final batch multiscalar multiplication + Scalar H_scalar; + std::vector Gi_scalars; + std::vector Hi_scalars; + std::vector commit_scalars; + Gi_scalars.resize(n*m); + Hi_scalars.resize(n*m); + commit_scalars.resize(commits.size()); + + // Set up the final batch elements + std::vector points; + std::vector scalars; + std::size_t final_size = 1 + 2*m*n + commits.size(); // F, (Gi), (Hi), (commits) + for (std::size_t t = 0; t < M; t++) { + final_size += 2 + proofs[t].X.size() + proofs[t].X1.size(); // A, B, (Gs), (Gv) + } + points.reserve(final_size); + scalars.reserve(final_size); + + // Index decomposition, which is common among all proofs + std::vector > I_; + I_.reserve(commits.size()); + I_.resize(commits.size()); + for (std::size_t i = 0; i < commits.size(); i++) { + I_[i] = decompose(i, n, m); + } + + // Process all proofs + for (std::size_t t = 0; t < M; t++) { + GrootleProof proof = proofs[t]; + + // Reconstruct the challenge + Transcript transcript("SPARK_GROOTLE"); + transcript.add("H", H); + transcript.add("Gi", Gi); + transcript.add("Hi", Hi); + transcript.add("n", Scalar(n)); + transcript.add("m", Scalar(m)); + transcript.add("S", std::vector(S.begin() + S.size() - sizes[t], S.end())); + transcript.add("S1", S1[t]); + transcript.add("V", std::vector(V.begin() + V.size() - sizes[t], V.end())); + transcript.add("V1", V1[t]); + transcript.add("A", proof.A); + transcript.add("B", proof.B); + transcript.add("X", proof.X); + transcript.add("X1", proof.X1); + Scalar x = transcript.challenge("x"); + + // Generate nonzero random verifier weights (the randomization already asserts nonzero) + Scalar w1, w2; + w1.randomize(); + w2.randomize(); + + // Reconstruct f-matrix + std::vector f_; + if (!compute_fs(proof, x, f_, n, m)) { + LogPrintf("Invalid matrix reconstruction"); + return false; + } + + // Effective set size + const std::size_t size = sizes[t]; + + // A, B (and associated commitments) + points.emplace_back(proof.A); + scalars.emplace_back(w1.negate()); + points.emplace_back(proof.B); + scalars.emplace_back(x.negate() * w1); + + H_scalar += proof.z * w1; + for (std::size_t i = 0; i < m * n; i++) { + Gi_scalars[i] += f_[i] * w1; + Hi_scalars[i] += f_[i]*(x - f_[i]) * w1; + } + + // Input sets + H_scalar += (proof.zS + bind_weight * proof.zV) * w2.negate(); + + Scalar f_sum; + Scalar f_i(uint64_t(1)); + std::vector::iterator ptr = commit_scalars.begin() + commits.size() - size; + compute_batch_fis(f_sum, f_i, m, f_, w2, ptr, ptr, ptr + size - 1, n); + + Scalar pow(uint64_t(1)); + std::vector f_part_product; + for (std::ptrdiff_t j = m - 1; j >= 0; j--) { + f_part_product.push_back(pow); + pow *= f_[j*n + I_[size - 1][j]]; + } + + Scalar x_powers(uint64_t(1)); + for (std::size_t j = 0; j < m; j++) { + Scalar fi_sum(uint64_t(0)); + for (std::size_t i = I_[size - 1][j] + 1; i < n; i++) + fi_sum += f_[j*n + i]; + pow += fi_sum * x_powers * f_part_product[m - j - 1]; + x_powers *= x; + } + + f_sum += pow; + commit_scalars[commits.size() - 1] += pow * w2; + + // S1, V1 + points.emplace_back(S1[t] + V1[t] * bind_weight); + scalars.emplace_back(f_sum * w2.negate()); + + // (X), (X1) + x_powers = Scalar(uint64_t(1)); + for (std::size_t j = 0; j < m; j++) { + points.emplace_back(proof.X[j] + proof.X1[j] * bind_weight); + scalars.emplace_back(x_powers.negate() * w2); + x_powers *= x; + } + } + + // Add common generators + points.emplace_back(H); + scalars.emplace_back(H_scalar); + for (std::size_t i = 0; i < m * n; i++) { + points.emplace_back(Gi[i]); + scalars.emplace_back(Gi_scalars[i]); + points.emplace_back(Hi[i]); + scalars.emplace_back(Hi_scalars[i]); + } + for (std::size_t i = 0; i < commits.size(); i++) { + points.emplace_back(commits[i]); + scalars.emplace_back(commit_scalars[i]); + } + + // Verify the batch + secp_primitives::MultiExponent result(points, scalars); + if (result.get_multiple().isInfinity()) { + return true; + } + return false; +} + +} \ No newline at end of file diff --git a/src/libspark/grootle.h b/src/libspark/grootle.h new file mode 100644 index 0000000000..d8b03d7d2d --- /dev/null +++ b/src/libspark/grootle.h @@ -0,0 +1,53 @@ +#ifndef FIRO_LIBSPARK_GROOTLE_H +#define FIRO_LIBSPARK_GROOTLE_H + +#include "grootle_proof.h" +#include +#include +#include "util.h" + +namespace spark { + +class Grootle { + +public: + Grootle( + const GroupElement& H, + const std::vector& Gi, + const std::vector& Hi, + const std::size_t n, + const std::size_t m + ); + + void prove(const std::size_t l, + const Scalar& s, + const std::vector& S, + const GroupElement& S1, + const Scalar& v, + const std::vector& V, + const GroupElement& V1, + GrootleProof& proof); + bool verify(const std::vector& S, + const GroupElement& S1, + const std::vector& V, + const GroupElement& V1, + const std::size_t size, + const GrootleProof& proof); // single proof + bool verify(const std::vector& S, + const std::vector& S1, + const std::vector& V, + const std::vector& V1, + const std::vector& sizes, + const std::vector& proofs); // batch of proofs + +private: + GroupElement H; + std::vector Gi; + std::vector Hi; + std::size_t n; + std::size_t m; +}; + +} + +#endif diff --git a/src/libspark/grootle_proof.h b/src/libspark/grootle_proof.h new file mode 100644 index 0000000000..3530f7e343 --- /dev/null +++ b/src/libspark/grootle_proof.h @@ -0,0 +1,45 @@ +#ifndef FIRO_LIBSPARK_GROOTLE_PROOF_H +#define FIRO_LIBSPARK_GROOTLE_PROOF_H + +#include "params.h" + +namespace spark { + +class GrootleProof { +public: + + inline std::size_t memoryRequired() const { + return 2*GroupElement::memoryRequired() + X.size()*GroupElement::memoryRequired() + X1.size()*GroupElement::memoryRequired() + f.size()*Scalar::memoryRequired() + 3*Scalar::memoryRequired(); + } + + inline std::size_t memoryRequired(int n, int m) const { + return 2*GroupElement::memoryRequired() + 2*m*GroupElement::memoryRequired() + m*(n-1)*Scalar::memoryRequired() + 3*Scalar::memoryRequired(); + } + + ADD_SERIALIZE_METHODS; + template + inline void SerializationOp(Stream& s, Operation ser_action) { + READWRITE(A); + READWRITE(B); + READWRITE(X); + READWRITE(X1); + READWRITE(f); + READWRITE(z); + READWRITE(zS); + READWRITE(zV); + } + +public: + GroupElement A; + GroupElement B; + std::vector X; + std::vector X1; + std::vector f; + Scalar z; + Scalar zS; + Scalar zV; +}; + +} + +#endif diff --git a/src/libspark/hash.cpp b/src/libspark/hash.cpp new file mode 100644 index 0000000000..050b8e1c8d --- /dev/null +++ b/src/libspark/hash.cpp @@ -0,0 +1,148 @@ +#include "hash.h" + +namespace spark { + +using namespace secp_primitives; + +// Set up a labeled hash function +Hash::Hash(const std::string label) { + this->ctx = EVP_MD_CTX_new(); + EVP_DigestInit_ex(this->ctx, EVP_blake2b512(), NULL); + + // Write the protocol and mode information + std::vector protocol(LABEL_PROTOCOL.begin(), LABEL_PROTOCOL.end()); + EVP_DigestUpdate(this->ctx, protocol.data(), protocol.size()); + EVP_DigestUpdate(this->ctx, &HASH_MODE_FUNCTION, sizeof(HASH_MODE_FUNCTION)); + + // Include the label with size + include_size(label.size()); + std::vector label_bytes(label.begin(), label.end()); + EVP_DigestUpdate(this->ctx, label_bytes.data(), label_bytes.size()); +} + +// Clean up +Hash::~Hash() { + EVP_MD_CTX_free(this->ctx); +} + +// Include serialized data in the hash function +void Hash::include(CDataStream& data) { + include_size(data.size()); + EVP_DigestUpdate(this->ctx, reinterpret_cast(data.data()), data.size()); +} + +// Finalize the hash function to a scalar +Scalar Hash::finalize_scalar() { + // Ensure we can properly populate a scalar + if (EVP_MD_size(EVP_blake2b512()) < SCALAR_ENCODING) { + throw std::runtime_error("Bad hash size!"); + } + + std::vector hash; + hash.resize(EVP_MD_size(EVP_blake2b512())); + unsigned char counter = 0; + + EVP_MD_CTX* state_counter; + state_counter = EVP_MD_CTX_new(); + EVP_DigestInit_ex(state_counter, EVP_blake2b512(), NULL); + + EVP_MD_CTX* state_finalize; + state_finalize = EVP_MD_CTX_new(); + EVP_DigestInit_ex(state_finalize, EVP_blake2b512(), NULL); + + while (1) { + // Prepare temporary state for counter testing + EVP_MD_CTX_copy_ex(state_counter, this->ctx); + + // Embed the counter + EVP_DigestUpdate(state_counter, &counter, sizeof(counter)); + + // Finalize the hash with a temporary state + EVP_MD_CTX_copy_ex(state_finalize, state_counter); + unsigned int TEMP; // We already know the digest length! + EVP_DigestFinal_ex(state_finalize, hash.data(), &TEMP); + + // Check for scalar validity + Scalar candidate; + try { + candidate.deserialize(hash.data()); + + EVP_MD_CTX_free(state_counter); + EVP_MD_CTX_free(state_finalize); + + return candidate; + } catch (...) { + counter++; + } + } +} + +// Finalize the hash function to a group element +GroupElement Hash::finalize_group() { + const int GROUP_ENCODING = 34; + const unsigned char ZERO = 0; + + // Ensure we can properly populate a + if (EVP_MD_size(EVP_blake2b512()) < GROUP_ENCODING) { + throw std::runtime_error("Bad hash size!"); + } + + std::vector hash; + hash.resize(EVP_MD_size(EVP_blake2b512())); + unsigned char counter = 0; + + EVP_MD_CTX* state_counter; + state_counter = EVP_MD_CTX_new(); + EVP_DigestInit_ex(state_counter, EVP_blake2b512(), NULL); + + EVP_MD_CTX* state_finalize; + state_finalize = EVP_MD_CTX_new(); + EVP_DigestInit_ex(state_finalize, EVP_blake2b512(), NULL); + + while (1) { + // Prepare temporary state for counter testing + EVP_MD_CTX_copy_ex(state_counter, this->ctx); + + // Embed the counter + EVP_DigestUpdate(state_counter, &counter, sizeof(counter)); + + // Finalize the hash with a temporary state + EVP_MD_CTX_copy_ex(state_finalize, state_counter); + unsigned int TEMP; // We already know the digest length! + EVP_DigestFinal_ex(state_finalize, hash.data(), &TEMP); + + // Assemble the serialized input: + // bytes 0..31: x coordinate + // byte 32: even/odd + // byte 33: zero (this point is not infinity) + unsigned char candidate_bytes[GROUP_ENCODING]; + memcpy(candidate_bytes, hash.data(), 33); + memcpy(candidate_bytes + 33, &ZERO, 1); + GroupElement candidate; + try { + candidate.deserialize(candidate_bytes); + + // Deserialization can succeed even with an invalid result + if (!candidate.isMember()) { + counter++; + continue; + } + + EVP_MD_CTX_free(state_counter); + EVP_MD_CTX_free(state_finalize); + + return candidate; + } catch (...) { + counter++; + } + } +} + +// Include a serialized size in the hash function +void Hash::include_size(std::size_t size) { + CDataStream stream(SER_NETWORK, PROTOCOL_VERSION); + stream << size; + EVP_DigestUpdate(this->ctx, reinterpret_cast(stream.data()), stream.size()); +} + +} \ No newline at end of file diff --git a/src/libspark/hash.h b/src/libspark/hash.h new file mode 100644 index 0000000000..dfd63ccb5e --- /dev/null +++ b/src/libspark/hash.h @@ -0,0 +1,25 @@ +#ifndef FIRO_SPARK_HASH_H +#define FIRO_SPARK_HASH_H +#include +#include "util.h" + +namespace spark { + +using namespace secp_primitives; + +class Hash { +public: + Hash(const std::string label); + ~Hash(); + void include(CDataStream& data); + Scalar finalize_scalar(); + GroupElement finalize_group(); + +private: + void include_size(std::size_t size); + EVP_MD_CTX* ctx; +}; + +} + +#endif diff --git a/src/libspark/kdf.cpp b/src/libspark/kdf.cpp new file mode 100644 index 0000000000..920974e976 --- /dev/null +++ b/src/libspark/kdf.cpp @@ -0,0 +1,57 @@ +#include "kdf.h" + +namespace spark { + +// Set up a labeled KDF +KDF::KDF(const std::string label) { + this->ctx = EVP_MD_CTX_new(); + EVP_DigestInit_ex(this->ctx, EVP_blake2b512(), NULL); + + // Write the protocol and mode information + std::vector protocol(LABEL_PROTOCOL.begin(), LABEL_PROTOCOL.end()); + EVP_DigestUpdate(this->ctx, protocol.data(), protocol.size()); + EVP_DigestUpdate(this->ctx, &HASH_MODE_KDF, sizeof(HASH_MODE_KDF)); + + // Include the label with size + include_size(label.size()); + std::vector label_bytes(label.begin(), label.end()); + EVP_DigestUpdate(this->ctx, label_bytes.data(), label_bytes.size()); +} + +// Clean up +KDF::~KDF() { + EVP_MD_CTX_free(this->ctx); +} + +// Include serialized data in the KDF +void KDF::include(CDataStream& data) { + include_size(data.size()); + EVP_DigestUpdate(this->ctx, reinterpret_cast(data.data()), data.size()); +} + +// Finalize the KDF with arbitrary size +std::vector KDF::finalize(std::size_t size) { + // Assert valid size + const std::size_t hash_size = EVP_MD_size(EVP_blake2b512()); + if (size > hash_size) { + throw std::invalid_argument("Requested KDF size is too large"); + } + + std::vector result; + result.resize(hash_size); + + unsigned int TEMP; + EVP_DigestFinal_ex(this->ctx, result.data(), &TEMP); + result.resize(size); + + return result; +} + +// Include a serialized size in the KDF +void KDF::include_size(std::size_t size) { + CDataStream stream(SER_NETWORK, PROTOCOL_VERSION); + stream << size; + EVP_DigestUpdate(this->ctx, reinterpret_cast(stream.data()), stream.size()); +} + +} \ No newline at end of file diff --git a/src/libspark/kdf.h b/src/libspark/kdf.h new file mode 100644 index 0000000000..1c5348b0e6 --- /dev/null +++ b/src/libspark/kdf.h @@ -0,0 +1,22 @@ +#ifndef FIRO_SPARK_KDF_H +#define FIRO_SPARK_KDF_H +#include +#include "util.h" + +namespace spark { + +class KDF { +public: + KDF(const std::string label); + ~KDF(); + void include(CDataStream& data); + std::vector finalize(std::size_t size); + +private: + void include_size(std::size_t size); + EVP_MD_CTX* ctx; +}; + +} + +#endif diff --git a/src/libspark/keys.cpp b/src/libspark/keys.cpp new file mode 100644 index 0000000000..58047c4cbe --- /dev/null +++ b/src/libspark/keys.cpp @@ -0,0 +1,118 @@ +#include "keys.h" + +namespace spark { + +using namespace secp_primitives; + +SpendKey::SpendKey() {} +SpendKey::SpendKey(const Params* params) { + this->params = params; + this->s1.randomize(); + this->s2.randomize(); + this->r.randomize(); +} + +const Params* SpendKey::get_params() const { + return this->params; +} + +const Scalar& SpendKey::get_s1() const { + return this->s1; +} + +const Scalar& SpendKey::get_s2() const { + return this->s2; +} + +const Scalar& SpendKey::get_r() const { + return this->r; +} + +FullViewKey::FullViewKey() {} +FullViewKey::FullViewKey(const SpendKey& spend_key) { + this->params = spend_key.get_params(); + this->s1 = spend_key.get_s1(); + this->s2 = spend_key.get_s2(); + this->D = this->params->get_G()*spend_key.get_r(); + this->P2 = this->params->get_F()*this->s2 + this->D; +} + +const Params* FullViewKey::get_params() const { + return this->params; +} + +const Scalar& FullViewKey::get_s1() const { + return this->s1; +} + +const Scalar& FullViewKey::get_s2() const { + return this->s2; +} + +const GroupElement& FullViewKey::get_D() const { + return this->D; +} + +const GroupElement& FullViewKey::get_P2() const { + return this->P2; +} + +IncomingViewKey::IncomingViewKey() {} +IncomingViewKey::IncomingViewKey(const FullViewKey& full_view_key) { + this->params = full_view_key.get_params(); + this->s1 = full_view_key.get_s1(); + this->P2 = full_view_key.get_P2(); +} + +const Params* IncomingViewKey::get_params() const { + return this->params; +} + +const Scalar& IncomingViewKey::get_s1() const { + return this->s1; +} + +const GroupElement& IncomingViewKey::get_P2() const { + return this->P2; +} + +uint64_t IncomingViewKey::get_diversifier(const std::vector& d) const { + // Assert proper size + if (d.size() != AES_BLOCKSIZE) { + throw std::invalid_argument("Bad encrypted diversifier"); + } + + // Decrypt the diversifier; this is NOT AUTHENTICATED and MUST be externally checked for validity against a claimed address + std::vector key = SparkUtils::kdf_diversifier(this->s1); + uint64_t i = SparkUtils::diversifier_decrypt(key, d); + + return i; +} + +Address::Address() {} +Address::Address(const IncomingViewKey& incoming_view_key, const uint64_t i) { + // Encrypt the diversifier + std::vector key = SparkUtils::kdf_diversifier(incoming_view_key.get_s1()); + this->params = incoming_view_key.get_params(); + this->d = SparkUtils::diversifier_encrypt(key, i); + this->Q1 = SparkUtils::hash_div(this->d)*incoming_view_key.get_s1(); + this->Q2 = this->params->get_F()*SparkUtils::hash_Q2(incoming_view_key.get_s1(), i) + incoming_view_key.get_P2(); +} + +const Params* Address::get_params() const { + return this->params; +} + +const std::vector& Address::get_d() const { + return this->d; +} + +const GroupElement& Address::get_Q1() const { + return this->Q1; +} + +const GroupElement& Address::get_Q2() const { + return this->Q2; +} + +} diff --git a/src/libspark/keys.h b/src/libspark/keys.h new file mode 100644 index 0000000000..1f34c80bc9 --- /dev/null +++ b/src/libspark/keys.h @@ -0,0 +1,72 @@ +#ifndef FIRO_SPARK_KEYS_H +#define FIRO_SPARK_KEYS_H +#include "params.h" +#include "util.h" + +namespace spark { + +using namespace secp_primitives; + +class SpendKey { +public: + SpendKey(); + SpendKey(const Params* params); + const Params* get_params() const; + const Scalar& get_s1() const; + const Scalar& get_s2() const; + const Scalar& get_r() const; + +private: + const Params* params; + Scalar s1, s2, r; +}; + +class FullViewKey { +public: + FullViewKey(); + FullViewKey(const SpendKey& spend_key); + const Params* get_params() const; + const Scalar& get_s1() const; + const Scalar& get_s2() const; + const GroupElement& get_D() const; + const GroupElement& get_P2() const; + +private: + const Params* params; + Scalar s1, s2; + GroupElement D, P2; +}; + +class IncomingViewKey { +public: + IncomingViewKey(); + IncomingViewKey(const FullViewKey& full_view_key); + const Params* get_params() const; + const Scalar& get_s1() const; + const GroupElement& get_P2() const; + uint64_t get_diversifier(const std::vector& d) const; + +private: + const Params* params; + Scalar s1; + GroupElement P2; +}; + +class Address { +public: + Address(); + Address(const IncomingViewKey& incoming_view_key, const uint64_t i); + const Params* get_params() const; + const std::vector& get_d() const; + const GroupElement& get_Q1() const; + const GroupElement& get_Q2() const; + +private: + const Params* params; + std::vector d; + GroupElement Q1, Q2; +}; + +} + +#endif diff --git a/src/libspark/mint_transaction.cpp b/src/libspark/mint_transaction.cpp new file mode 100644 index 0000000000..7c99e6f498 --- /dev/null +++ b/src/libspark/mint_transaction.cpp @@ -0,0 +1,43 @@ +#include "mint_transaction.h" + +namespace spark { + +MintTransaction::MintTransaction( + const Params* params, + const Address& address, + uint64_t v, + const std::string memo +) { + this->params = params; + + // Generate the coin + Scalar k; + k.randomize(); + this->coin = Coin( + this->params, + COIN_TYPE_MINT, + k, + address, + v, + memo + ); + + // Generate the balance proof + Schnorr schnorr(this->params->get_H()); + schnorr.prove( + SparkUtils::hash_val(k), + this->coin.C + this->params->get_G().inverse()*Scalar(v), + this->balance_proof + ); +} + +bool MintTransaction::verify() { + // Verify the balance proof + Schnorr schnorr(this->params->get_H()); + return schnorr.verify( + this->coin.C + this->params->get_G().inverse()*Scalar(this->coin.v), + this->balance_proof + ); +} + +} diff --git a/src/libspark/mint_transaction.h b/src/libspark/mint_transaction.h new file mode 100644 index 0000000000..7e14d45138 --- /dev/null +++ b/src/libspark/mint_transaction.h @@ -0,0 +1,30 @@ +#ifndef FIRO_SPARK_MINT_TRANSACTION_H +#define FIRO_SPARK_MINT_TRANSACTION_H +#include "keys.h" +#include "coin.h" +#include "schnorr.h" +#include "util.h" + +namespace spark { + +using namespace secp_primitives; + +class MintTransaction { +public: + MintTransaction( + const Params* params, + const Address& address, + uint64_t v, + const std::string memo + ); + bool verify(); + +private: + const Params* params; + Coin coin; + SchnorrProof balance_proof; +}; + +} + +#endif diff --git a/src/libspark/params.cpp b/src/libspark/params.cpp new file mode 100644 index 0000000000..a3ddb8392b --- /dev/null +++ b/src/libspark/params.cpp @@ -0,0 +1,137 @@ +#include "params.h" +#include "chainparams.h" +#include "util.h" + +namespace spark { + + CCriticalSection Params::cs_instance; + std::unique_ptr Params::instance; + +// Protocol parameters for deployment +Params const* Params::get_default() { + if (instance) { + return instance.get(); + } else { + LOCK(cs_instance); + if (instance) { + return instance.get(); + } + + std::size_t memo_bytes = 32; + std::size_t max_M_range = 16; + std::size_t n_grootle = 16; + std::size_t m_grootle = 4; + + instance.reset(new Params(memo_bytes, max_M_range, n_grootle, m_grootle)); + return instance.get(); + } +} + +// Protocol parameters for testing +Params const* Params::get_test() { + if (instance) { + return instance.get(); + } else { + LOCK(cs_instance); + if (instance) { + return instance.get(); + } + + std::size_t memo_bytes = 32; + std::size_t max_M_range = 16; + std::size_t n_grootle = 2; + std::size_t m_grootle = 4; + + instance.reset(new Params(memo_bytes, max_M_range, n_grootle, m_grootle)); + return instance.get(); + } +} + +Params::Params( + const std::size_t memo_bytes, + const std::size_t max_M_range, + const std::size_t n_grootle, + const std::size_t m_grootle +) +{ + // Global generators + this->F = SparkUtils::hash_generator(LABEL_GENERATOR_F); + this->G = SparkUtils::hash_generator(LABEL_GENERATOR_G); + this->H = SparkUtils::hash_generator(LABEL_GENERATOR_H); + this->U = SparkUtils::hash_generator(LABEL_GENERATOR_U); + + // Coin parameters + this->memo_bytes = memo_bytes; + + // Range proof parameters + this->max_M_range = max_M_range; + this->G_range.resize(64*max_M_range); + this->H_range.resize(64*max_M_range); + for (std::size_t i = 0; i < 64*max_M_range; i++) { + this->G_range[i] = SparkUtils::hash_generator(LABEL_GENERATOR_G_RANGE + " " + std::to_string(i)); + this->H_range[i] = SparkUtils::hash_generator(LABEL_GENERATOR_H_RANGE + " " + std::to_string(i)); + } + + // One-of-many parameters + if (n_grootle < 2 || m_grootle < 3) { + throw std::invalid_argument("Bad Grootle parameteres"); + } + this->n_grootle = n_grootle; + this->m_grootle = m_grootle; + this->G_grootle.resize(n_grootle * m_grootle); + this->H_grootle.resize(n_grootle * m_grootle); + for (std::size_t i = 0; i < n_grootle * m_grootle; i++) { + this->G_grootle[i] = SparkUtils::hash_generator(LABEL_GENERATOR_G_GROOTLE + " " + std::to_string(i)); + this->H_grootle[i] = SparkUtils::hash_generator(LABEL_GENERATOR_H_GROOTLE + " " + std::to_string(i)); + } +} + +const GroupElement& Params::get_F() const { + return this->F; +} + +const GroupElement& Params::get_G() const { + return this->G; +} + +const GroupElement& Params::get_H() const { + return this->H; +} + +const GroupElement& Params::get_U() const { + return this->U; +} + +const std::size_t Params::get_memo_bytes() const { + return this->memo_bytes; +} + +const std::vector& Params::get_G_range() const { + return this->G_range; +} + +const std::vector& Params::get_H_range() const { + return this->H_range; +} + +const std::vector& Params::get_G_grootle() const { + return this->G_grootle; +} + +const std::vector& Params::get_H_grootle() const { + return this->H_grootle; +} + +std::size_t Params::get_max_M_range() const { + return this->max_M_range; +} + +std::size_t Params::get_n_grootle() const { + return this->n_grootle; +} + +std::size_t Params::get_m_grootle() const { + return this->m_grootle; +} + +} diff --git a/src/libspark/params.h b/src/libspark/params.h new file mode 100644 index 0000000000..e37855dab4 --- /dev/null +++ b/src/libspark/params.h @@ -0,0 +1,67 @@ +#ifndef FIRO_LIBSPARK_PARAMS_H +#define FIRO_LIBSPARK_PARAMS_H + +#include +#include +#include +#include + +using namespace secp_primitives; + +namespace spark { + +class Params { +public: + static Params const* get_default(); + static Params const* get_test(); + + const GroupElement& get_F() const; + const GroupElement& get_G() const; + const GroupElement& get_H() const; + const GroupElement& get_U() const; + + const std::size_t get_memo_bytes() const; + + std::size_t get_max_M_range() const; + const std::vector& get_G_range() const; + const std::vector& get_H_range() const; + + std::size_t get_n_grootle() const; + std::size_t get_m_grootle() const; + const std::vector& get_G_grootle() const; + const std::vector& get_H_grootle() const; + +private: + Params( + const std::size_t memo_bytes, + const std::size_t max_M_range, + const std::size_t n_grootle, + const std::size_t m_grootle + ); + +private: + static CCriticalSection cs_instance; + static std::unique_ptr instance; + + // Global generators + GroupElement F; + GroupElement G; + GroupElement H; + GroupElement U; + + // Coin parameters + std::size_t memo_bytes; + + // Range proof parameters + std::size_t max_M_range; + std::vector G_range, H_range; + + // One-of-many parameters + std::size_t n_grootle, m_grootle; + std::vector G_grootle; + std::vector H_grootle; +}; + +} + +#endif diff --git a/src/libspark/schnorr.cpp b/src/libspark/schnorr.cpp new file mode 100644 index 0000000000..d7ffce9b10 --- /dev/null +++ b/src/libspark/schnorr.cpp @@ -0,0 +1,40 @@ +#include "schnorr.h" +#include "transcript.h" + +namespace spark { + +Schnorr::Schnorr(const GroupElement& G_): + G(G_) { +} + +Scalar Schnorr::challenge( + const GroupElement& Y, + const GroupElement& A) { + Transcript transcript("SPARK_SCHNORR"); + transcript.add("G", G); + transcript.add("Y", Y); + transcript.add("A", A); + + return transcript.challenge("c"); +} + +void Schnorr::prove(const Scalar& y, const GroupElement& Y, SchnorrProof& proof) { + // Check statement validity + if (!(G*y == Y)) { + throw std::invalid_argument("Bad Schnorr statement!"); + } + + Scalar r; + r.randomize(); + GroupElement A = G*r; + proof.c = challenge(Y, A); + proof.t = r + proof.c*y; +} + +bool Schnorr::verify(const GroupElement& Y, SchnorrProof& proof) { + Scalar c = challenge(Y, G*proof.t + Y.inverse()*proof.c); + + return c == proof.c; +} + +} diff --git a/src/libspark/schnorr.h b/src/libspark/schnorr.h new file mode 100644 index 0000000000..7cc40b0df1 --- /dev/null +++ b/src/libspark/schnorr.h @@ -0,0 +1,22 @@ +#ifndef FIRO_LIBSPARK_SCHNORR_H +#define FIRO_LIBSPARK_SCHNORR_H + +#include "schnorr_proof.h" + +namespace spark { + +class Schnorr { +public: + Schnorr(const GroupElement& G); + + void prove(const Scalar& y, const GroupElement& Y, SchnorrProof& proof); + bool verify(const GroupElement& Y, SchnorrProof& proof); + +private: + Scalar challenge(const GroupElement& Y, const GroupElement& A); + const GroupElement& G; +}; + +} + +#endif diff --git a/src/libspark/schnorr_proof.h b/src/libspark/schnorr_proof.h new file mode 100644 index 0000000000..c5fa96c776 --- /dev/null +++ b/src/libspark/schnorr_proof.h @@ -0,0 +1,27 @@ +#ifndef FIRO_LIBSPARK_SCHNORR_PROOF_H +#define FIRO_LIBSPARK_SCHNORR_PROOF_H + +#include "params.h" + +namespace spark { + +class SchnorrProof{ +public: + inline std::size_t memoryRequired() const { + return 2*Scalar::memoryRequired(); + } + + ADD_SERIALIZE_METHODS; + template + inline void SerializationOp(Stream& s, Operation ser_action) { + READWRITE(c); + READWRITE(t); + } + +public: + Scalar c; + Scalar t; +}; +} + +#endif diff --git a/src/libspark/spend_transaction.cpp b/src/libspark/spend_transaction.cpp new file mode 100644 index 0000000000..dc7585b4ac --- /dev/null +++ b/src/libspark/spend_transaction.cpp @@ -0,0 +1,317 @@ +#include "spend_transaction.h" + +namespace spark { + +SpendTransaction::SpendTransaction( + const Params* params, + const FullViewKey& full_view_key, + const SpendKey& spend_key, + const std::vector& in_coins, + const std::vector& inputs, + const uint64_t f, + const std::vector& outputs +) { + this->params = params; + + // Size parameters + const std::size_t w = inputs.size(); // number of consumed coins + const std::size_t t = outputs.size(); // number of generated coins + const std::size_t N = in_coins.size(); // size of cover set + + // Prepare input-related vectors + this->in_coins = in_coins; // input cover set + this->S1.reserve(w); // serial commitment offsets + this->C1.reserve(w); // value commitment offsets + this->grootle_proofs.reserve(w); // Grootle one-of-many proofs + this->T.reserve(w); // linking tags + + this->f = f; // fee + + // Prepare Chaum vectors + std::vector chaum_x, chaum_y, chaum_z; + + // Prepare output vector + this->out_coins.reserve(t); // coins + std::vector k; // nonces + + // Parse out serial and value commitments from the cover set for use in proofs + std::vector S, C; + S.resize(N); + C.resize(N); + for (std::size_t i = 0; i < N; i++) { + S[i] = in_coins[i].S; + C[i] = in_coins[i].C; + } + + // Prepare inputs + Grootle grootle( + this->params->get_H(), + this->params->get_G_grootle(), + this->params->get_H_grootle(), + this->params->get_n_grootle(), + this->params->get_m_grootle() + ); + for (std::size_t u = 0; u < w; u++) { + // Serial commitment offset + this->S1.emplace_back( + this->params->get_F()*inputs[u].s + + this->params->get_H().inverse()*SparkUtils::hash_ser1(inputs[u].s, full_view_key.get_D()) + + full_view_key.get_D() + ); + + // Value commitment offset + this->C1.emplace_back( + this->params->get_G()*Scalar(inputs[u].v) + + this->params->get_H()*SparkUtils::hash_val1(inputs[u].s, full_view_key.get_D()) + ); + + // Tags + this->T.emplace_back(inputs[u].T); + + // Grootle proof + this->grootle_proofs.emplace_back(); + std::size_t l = inputs[u].index; + grootle.prove( + l, + SparkUtils::hash_ser1(inputs[u].s, full_view_key.get_D()), + S, + this->S1.back(), + SparkUtils::hash_val(inputs[u].k) - SparkUtils::hash_val1(inputs[u].s, full_view_key.get_D()), + C, + this->C1.back(), + this->grootle_proofs.back() + ); + + // Chaum data + chaum_x.emplace_back(inputs[u].s); + chaum_y.emplace_back(spend_key.get_r()); + chaum_z.emplace_back(SparkUtils::hash_ser1(inputs[u].s, full_view_key.get_D()).negate()); + } + + // Generate output coins and prepare range proof vectors + std::vector range_v; + std::vector range_r; + std::vector range_C; + for (std::size_t j = 0; j < t; j++) { + // Nonce + k.emplace_back(); + k.back().randomize(); + + // Output coin + this->out_coins.emplace_back(); + this->out_coins.back() = Coin( + this->params, + COIN_TYPE_SPEND, + k.back(), + outputs[j].address, + outputs[j].v, + outputs[j].memo + ); + + // Range data + range_v.emplace_back(outputs[j].v); + range_r.emplace_back(SparkUtils::hash_val(k.back())); + range_C.emplace_back(this->out_coins.back().C); + } + + // Generate range proof + BPPlus range( + this->params->get_G(), + this->params->get_H(), + this->params->get_G_range(), + this->params->get_H_range(), + 64 + ); + range.prove( + range_v, + range_r, + range_C, + this->range_proof + ); + + // Generate the balance proof + Schnorr schnorr(this->params->get_H()); + GroupElement balance_statement; + Scalar balance_witness; + for (std::size_t u = 0; u < w; u++) { + balance_statement += this->C1[u]; + balance_witness += SparkUtils::hash_val1(inputs[u].s, full_view_key.get_D()); + } + for (std::size_t j = 0; j < t; j++) { + balance_statement += this->out_coins[j].C.inverse(); + balance_witness -= SparkUtils::hash_val(k[j]); + } + balance_statement += this->params->get_G()*Scalar(f); + schnorr.prove( + balance_witness, + balance_statement, + this->balance_proof + ); + + // Compute the binding hash + Scalar mu = hash_bind( + this->in_coins, + this->out_coins, + this->f, + this->S1, + this->C1, + this->T, + this->grootle_proofs, + this->balance_proof, + this->range_proof + ); + + // Compute the authorizing Chaum proof + Chaum chaum( + this->params->get_F(), + this->params->get_G(), + this->params->get_H(), + this->params->get_U() + ); + chaum.prove( + mu, + chaum_x, + chaum_y, + chaum_z, + this->S1, + this->T, + this->chaum_proof + ); +} + +bool SpendTransaction::verify() { + // Size parameters + const std::size_t w = this->grootle_proofs.size(); + const std::size_t t = this->out_coins.size(); + const std::size_t N = this->in_coins.size(); + + // Semantics + if (this->S1.size() != w || this->C1.size() != w || this->T.size() != w) { + throw std::invalid_argument("Bad spend transaction semantics"); + } + if (N > (std::size_t)pow(this->params->get_n_grootle(), this->params->get_m_grootle())) { + throw std::invalid_argument("Bad spend transaction semantics"); + } + + // Parse out serial and value commitments from the cover set for use in proofs + std::vector S, C; + S.resize(N); + C.resize(N); + for (std::size_t i = 0; i < N; i++) { + S[i] = this->in_coins[i].S; + C[i] = this->in_coins[i].C; + } + + // Parse out value commitments from the output set for use in proofs + std::vector C_out; + C_out.resize(t); + for (std::size_t j = 0; j < t; j++) { + C_out[j] = this->out_coins[j].C; + } + + // Consumed coins + Grootle grootle( + this->params->get_H(), + this->params->get_G_grootle(), + this->params->get_H_grootle(), + this->params->get_n_grootle(), + this->params->get_m_grootle() + ); + + // Verify all Grootle proofs in a batch + std::vector sizes; + for (std::size_t u = 0; u < w; u++) { + sizes.emplace_back(N); + } + if (!grootle.verify(S, this->S1, C, this->C1, sizes, this->grootle_proofs)) { + return false; + } + + // Compute the binding hash + Scalar mu = hash_bind( + this->in_coins, + this->out_coins, + this->f, + this->S1, + this->C1, + this->T, + this->grootle_proofs, + this->balance_proof, + this->range_proof + ); + + // Verify the authorizing Chaum proof + Chaum chaum( + this->params->get_F(), + this->params->get_G(), + this->params->get_H(), + this->params->get_U() + ); + if (!chaum.verify(mu, this->S1, this->T, this->chaum_proof)) { + return false; + } + + // Verify the aggregated range proof + BPPlus range( + this->params->get_G(), + this->params->get_H(), + this->params->get_G_range(), + this->params->get_H_range(), + 64 + ); + if (!range.verify(C_out, this->range_proof)) { + return false; + } + + // Verify the balance proof + Schnorr schnorr(this->params->get_H()); + GroupElement balance_statement; + for (std::size_t u = 0; u < w; u++) { + balance_statement += this->C1[u]; + } + for (std::size_t j = 0; j < t; j++) { + balance_statement += this->out_coins[j].C.inverse(); + } + balance_statement += this->params->get_G()*Scalar(this->f); + if(!schnorr.verify( + balance_statement, + this->balance_proof + )) { + return false; + } + + return true; +} + +// Hash-to-scalar function H_bind +Scalar SpendTransaction::hash_bind( + const std::vector& in_coins, + const std::vector& out_coins, + const uint64_t f, + const std::vector& S1, + const std::vector& C1, + const std::vector& T, + const std::vector& grootle_proofs, + const SchnorrProof& balance_proof, + const BPPlusProof& range_proof +) { + Hash hash(LABEL_HASH_BIND); + + CDataStream stream(SER_NETWORK, PROTOCOL_VERSION); + + // Perform the serialization and hashing + stream << in_coins; + stream << out_coins; + stream << f; + stream << S1; + stream << C1; + stream << T; + stream << grootle_proofs; + stream << balance_proof; + stream << range_proof; + hash.include(stream); + + return hash.finalize_scalar(); +} + +} diff --git a/src/libspark/spend_transaction.h b/src/libspark/spend_transaction.h new file mode 100644 index 0000000000..7110058eda --- /dev/null +++ b/src/libspark/spend_transaction.h @@ -0,0 +1,68 @@ +#ifndef FIRO_SPARK_SPEND_TRANSACTION_H +#define FIRO_SPARK_SPEND_TRANSACTION_H +#include "keys.h" +#include "coin.h" +#include "schnorr.h" +#include "util.h" +#include "grootle.h" +#include "bpplus.h" +#include "chaum.h" + +namespace spark { + +using namespace secp_primitives; + +struct InputCoinData { + std::size_t index; // index in cover set + Scalar s; // serial number + GroupElement T; // tag + uint64_t v; // value + Scalar k; // nonce +}; + +struct OutputCoinData { + Address address; + uint64_t v; + std::string memo; +}; + +class SpendTransaction { +public: + SpendTransaction( + const Params* params, + const FullViewKey& full_view_key, + const SpendKey& spend_key, + const std::vector& in_coins, + const std::vector& inputs, + const uint64_t f, + const std::vector& outputs + ); + bool verify(); + + static Scalar hash_bind( + const std::vector& in_coins, + const std::vector& out_coins, + const uint64_t f, + const std::vector& S1, + const std::vector& C1, + const std::vector& T, + const std::vector& grootle_proofs, + const SchnorrProof& balance_proof, + const BPPlusProof& range_proof + ); + +private: + const Params* params; + std::vector in_coins; + std::vector out_coins; + uint64_t f; + std::vector S1, C1, T; + std::vector grootle_proofs; + ChaumProof chaum_proof; + SchnorrProof balance_proof; + BPPlusProof range_proof; +}; + +} + +#endif diff --git a/src/libspark/test/aead_test.cpp b/src/libspark/test/aead_test.cpp new file mode 100644 index 0000000000..975fa495db --- /dev/null +++ b/src/libspark/test/aead_test.cpp @@ -0,0 +1,128 @@ +#include "../aead.h" + +#include "../../test/test_bitcoin.h" +#include + +namespace spark { + +BOOST_FIXTURE_TEST_SUITE(spark_aead_tests, BasicTestingSetup) + +BOOST_AUTO_TEST_CASE(complete) +{ + // Key + std::string key_string = "Key prefix"; + std::vector key(key_string.begin(), key_string.end()); + key.resize(AEAD_KEY_SIZE); + + // Serialize + int message = 12345; + CDataStream ser(SER_NETWORK, PROTOCOL_VERSION); + ser << message; + + // Encrypt + AEADEncryptedData data = AEAD::encrypt(key, "Associated data", ser); + + // Decrypt + ser = AEAD::decrypt_and_verify(key, "Associated data", data); + + // Deserialize + int message_; + ser >> message_; + + BOOST_CHECK_EQUAL(message_, message); +} + +BOOST_AUTO_TEST_CASE(bad_tag) +{ + // Key + std::string key_string = "Key prefix"; + std::vector key(key_string.begin(), key_string.end()); + key.resize(AEAD_KEY_SIZE); + + // Serialize and encrypt a message + int message = 12345; + CDataStream ser(SER_NETWORK, PROTOCOL_VERSION); + ser << message; + AEADEncryptedData data = AEAD::encrypt(key, "Associated data", ser); + + // Serialize and encrypt an evil message + ser.clear(); + int evil_message = 666; + ser << evil_message; + AEADEncryptedData evil_data = AEAD::encrypt(key, "Associated data", ser); + + // Replace tag + data.tag = evil_data.tag; + + // Decrypt; this should fail + BOOST_CHECK_THROW(ser = AEAD::decrypt_and_verify(key, "Associated data", data), std::runtime_error); +} + +BOOST_AUTO_TEST_CASE(bad_ciphertext) +{ + // Key + std::string key_string = "Key prefix"; + std::vector key(key_string.begin(), key_string.end()); + key.resize(AEAD_KEY_SIZE); + + // Serialize and encrypt a message + int message = 12345; + CDataStream ser(SER_NETWORK, PROTOCOL_VERSION); + ser << message; + AEADEncryptedData data = AEAD::encrypt(key, "Associated data", ser); + + // Serialize and encrypt an evil message + ser.clear(); + int evil_message = 666; + ser << evil_message; + AEADEncryptedData evil_data = AEAD::encrypt(key, "Associated data", ser); + + // Replace ciphertext + data.ciphertext = evil_data.ciphertext; + + // Decrypt; this should fail + BOOST_CHECK_THROW(ser = AEAD::decrypt_and_verify(key, "Associated data", data), std::runtime_error); +} + +BOOST_AUTO_TEST_CASE(bad_associated_data) +{ + // Key + std::string key_string = "Key prefix"; + std::vector key(key_string.begin(), key_string.end()); + key.resize(AEAD_KEY_SIZE); + + // Serialize and encrypt a message + int message = 12345; + CDataStream ser(SER_NETWORK, PROTOCOL_VERSION); + ser << message; + AEADEncryptedData data = AEAD::encrypt(key, "Associated data", ser); + + // Decrypt; this should fail + BOOST_CHECK_THROW(ser = AEAD::decrypt_and_verify(key, "Evil associated data", data), std::runtime_error); +} + +BOOST_AUTO_TEST_CASE(bad_key) +{ + // Key + std::string key_string = "Key prefix"; + std::vector key(key_string.begin(), key_string.end()); + key.resize(AEAD_KEY_SIZE); + + // Evil key + std::string evil_key_string = "Evil key prefix"; + std::vector evil_key(evil_key_string.begin(), evil_key_string.end()); + evil_key.resize(AEAD_KEY_SIZE); + + // Serialize and encrypt a message + int message = 12345; + CDataStream ser(SER_NETWORK, PROTOCOL_VERSION); + ser << message; + AEADEncryptedData data = AEAD::encrypt(key, "Associated data", ser); + + // Decrypt; this should fail + BOOST_CHECK_THROW(ser = AEAD::decrypt_and_verify(evil_key, "Associated data", data), std::runtime_error); +} + +BOOST_AUTO_TEST_SUITE_END() + +} \ No newline at end of file diff --git a/src/libspark/test/bpplus_test.cpp b/src/libspark/test/bpplus_test.cpp new file mode 100644 index 0000000000..74ddb6f827 --- /dev/null +++ b/src/libspark/test/bpplus_test.cpp @@ -0,0 +1,194 @@ +#include "../bpplus.h" + +#include "../../test/test_bitcoin.h" +#include + +namespace spark { + +BOOST_FIXTURE_TEST_SUITE(spark_bpplus_tests, BasicTestingSetup) + +// Generate and verify a single aggregated proof +BOOST_AUTO_TEST_CASE(completeness_single) +{ + // Parameters + std::size_t N = 64; // bit length + std::size_t M = 4; // aggregation + + // Generators + GroupElement G, H; + G.randomize(); + H.randomize(); + + std::vector Gi, Hi; + Gi.resize(N*M); + Hi.resize(N*M); + for (std::size_t i = 0; i < N*M; i++) { + Gi[i].randomize(); + Hi[i].randomize(); + } + + // Commitments + std::vector v, r; + v.resize(M); + v[0] = Scalar(uint64_t(0)); + v[1] = Scalar(uint64_t(1)); + v[2] = Scalar(uint64_t(2)); + v[3] = Scalar(std::numeric_limits::max()); + r.resize(M); + std::vector C; + C.resize(M); + for (std::size_t j = 0; j < M; j++) { + r[j].randomize(); + C[j] = G*v[j] + H*r[j]; + } + + BPPlus bpplus(G, H, Gi, Hi, N); + BPPlusProof proof; + bpplus.prove(v, r, C, proof); + + BOOST_CHECK(bpplus.verify(C, proof)); +} + +// A single proof with invalid value +BOOST_AUTO_TEST_CASE(invalid_single) +{ + // Parameters + std::size_t N = 64; // bit length + std::size_t M = 4; // aggregation + + // Generators + GroupElement G, H; + G.randomize(); + H.randomize(); + + std::vector Gi, Hi; + Gi.resize(N*M); + Hi.resize(N*M); + for (std::size_t i = 0; i < N*M; i++) { + Gi[i].randomize(); + Hi[i].randomize(); + } + + // Commitments + std::vector v, r; + v.resize(M); + v[0] = Scalar(uint64_t(0)); + v[1] = Scalar(uint64_t(1)); + v[2] = Scalar(uint64_t(2)); + v[3] = Scalar(std::numeric_limits::max()) + Scalar(uint64_t(1)); // out of range + r.resize(M); + std::vector C; + C.resize(M); + for (std::size_t j = 0; j < M; j++) { + r[j].randomize(); + C[j] = G*v[j] + H*r[j]; + } + + BPPlus bpplus(G, H, Gi, Hi, N); + BPPlusProof proof; + bpplus.prove(v, r, C, proof); + + BOOST_CHECK(!bpplus.verify(C, proof)); +} + +// Generate and verify a batch of proofs with variable aggregation +BOOST_AUTO_TEST_CASE(completeness_batch) +{ + // Parameters + std::size_t N = 64; // bit length + std::size_t B = 4; // number of proofs in batch + + // Generators + GroupElement G, H; + G.randomize(); + H.randomize(); + + std::vector Gi, Hi; + Gi.resize(N*(1 << B)); + Hi.resize(N*(1 << B)); + for (std::size_t i = 0; i < N*(1 << B); i++) { + Gi[i].randomize(); + Hi[i].randomize(); + } + + BPPlus bpplus(G, H, Gi, Hi, N); + std::vector proofs; + proofs.resize(B); + std::vector> C; + + // Build each proof + for (std::size_t i = 0; i < B; i++) { + // Commitments + std::size_t M = 1 << i; + std::vector v, r; + v.resize(M); + r.resize(M); + std::vector C_; + C_.resize(M); + for (std::size_t j = 0; j < M; j++) { + v[j] = Scalar(uint64_t(j)); + r[j].randomize(); + C_[j] = G*v[j] + H*r[j]; + } + C.emplace_back(C_); + + bpplus.prove(v, r, C_, proofs[i]); + } + + BOOST_CHECK(bpplus.verify(C, proofs)); +} + +// An invalid batch of proofs +BOOST_AUTO_TEST_CASE(invalid_batch) +{ + // Parameters + std::size_t N = 64; // bit length + std::size_t B = 4; // number of proofs in batch + + // Generators + GroupElement G, H; + G.randomize(); + H.randomize(); + + std::vector Gi, Hi; + Gi.resize(N*(1 << B)); + Hi.resize(N*(1 << B)); + for (std::size_t i = 0; i < N*(1 << B); i++) { + Gi[i].randomize(); + Hi[i].randomize(); + } + + BPPlus bpplus(G, H, Gi, Hi, N); + std::vector proofs; + proofs.resize(B); + std::vector> C; + + // Build each proof + for (std::size_t i = 0; i < B; i++) { + // Commitments + std::size_t M = 1 << i; + std::vector v, r; + v.resize(M); + r.resize(M); + std::vector C_; + C_.resize(M); + for (std::size_t j = 0; j < M; j++) { + v[j] = Scalar(uint64_t(j)); + // Set one proof to an out-of-range value; + if (i == 0 && j == 0) { + v[j] = Scalar(std::numeric_limits::max()) + Scalar(uint64_t(1)); + } + r[j].randomize(); + C_[j] = G*v[j] + H*r[j]; + } + C.emplace_back(C_); + + bpplus.prove(v, r, C_, proofs[i]); + } + + BOOST_CHECK(!bpplus.verify(C, proofs)); +} + +BOOST_AUTO_TEST_SUITE_END() + +} diff --git a/src/libspark/test/chaum_test.cpp b/src/libspark/test/chaum_test.cpp new file mode 100644 index 0000000000..26281438bd --- /dev/null +++ b/src/libspark/test/chaum_test.cpp @@ -0,0 +1,180 @@ +#include "../chaum.h" +#include "../../streams.h" +#include "../../version.h" + +#include "../../test/test_bitcoin.h" +#include + +namespace spark { + +BOOST_FIXTURE_TEST_SUITE(spark_chaum_tests, BasicTestingSetup) + +BOOST_AUTO_TEST_CASE(serialization) +{ + GroupElement F, G, H, U; + F.randomize(); + G.randomize(); + H.randomize(); + U.randomize(); + + const std::size_t n = 3; + + Scalar mu; + mu.randomize(); + std::vector x, y, z; + x.resize(n); + y.resize(n); + z.resize(n); + std::vector S, T; + S.resize(n); + T.resize(n); + for (std::size_t i = 0; i < n; i++) { + x[i].randomize(); + y[i].randomize(); + z[i].randomize(); + + S[i] = F*x[i] + G*y[i] + H*z[i]; + T[i] = (U + G*y[i].negate())*x[i].inverse(); + } + + ChaumProof proof; + + Chaum chaum(F, G, H, U); + chaum.prove(mu, x, y, z, S, T, proof); + + CDataStream serialized(SER_NETWORK, PROTOCOL_VERSION); + serialized << proof; + + ChaumProof deserialized; + serialized >> deserialized; + + BOOST_CHECK(proof.A1 == deserialized.A1); + BOOST_CHECK(proof.t2 == deserialized.t2); + BOOST_CHECK(proof.t3 == deserialized.t3); + for (std::size_t i = 0; i < n; i++) { + BOOST_CHECK(proof.A2[i] == deserialized.A2[i]); + BOOST_CHECK(proof.t1[i] == deserialized.t1[i]); + } +} + +BOOST_AUTO_TEST_CASE(completeness) +{ + GroupElement F, G, H, U; + F.randomize(); + G.randomize(); + H.randomize(); + U.randomize(); + + const std::size_t n = 3; + + Scalar mu; + mu.randomize(); + std::vector x, y, z; + x.resize(n); + y.resize(n); + z.resize(n); + std::vector S, T; + S.resize(n); + T.resize(n); + for (std::size_t i = 0; i < n; i++) { + x[i].randomize(); + y[i].randomize(); + z[i].randomize(); + + S[i] = F*x[i] + G*y[i] + H*z[i]; + T[i] = (U + G*y[i].negate())*x[i].inverse(); + } + + ChaumProof proof; + + Chaum chaum(F, G, H, U); + chaum.prove(mu, x, y, z, S, T, proof); + + BOOST_CHECK(chaum.verify(mu, S, T, proof)); +} + +BOOST_AUTO_TEST_CASE(bad_proofs) +{ + GroupElement F, G, H, U; + F.randomize(); + G.randomize(); + H.randomize(); + U.randomize(); + + const std::size_t n = 3; + + Scalar mu; + mu.randomize(); + std::vector x, y, z; + x.resize(n); + y.resize(n); + z.resize(n); + std::vector S, T; + S.resize(n); + T.resize(n); + for (std::size_t i = 0; i < n; i++) { + x[i].randomize(); + y[i].randomize(); + z[i].randomize(); + + S[i] = F*x[i] + G*y[i] + H*z[i]; + T[i] = (U + G*y[i].negate())*x[i].inverse(); + } + + ChaumProof proof; + + Chaum chaum(F, G, H, U); + chaum.prove(mu, x, y, z, S, T, proof); + + // Bad mu + Scalar evil_mu; + evil_mu.randomize(); + BOOST_CHECK(!(chaum.verify(evil_mu, S, T, proof))); + + // Bad S + for (std::size_t i = 0; i < n; i++) { + std::vector evil_S(S); + evil_S[i].randomize(); + BOOST_CHECK(!(chaum.verify(mu, evil_S, T, proof))); + } + + // Bad T + for (std::size_t i = 0; i < n; i++) { + std::vector evil_T(T); + evil_T[i].randomize(); + BOOST_CHECK(!(chaum.verify(mu, S, evil_T, proof))); + } + + // Bad A1 + ChaumProof evil_proof = proof; + evil_proof.A1.randomize(); + BOOST_CHECK(!(chaum.verify(mu, S, T, evil_proof))); + + // Bad A2 + for (std::size_t i = 0; i < n; i++) { + evil_proof = proof; + evil_proof.A2[i].randomize(); + BOOST_CHECK(!(chaum.verify(mu, S, T, evil_proof))); + } + + // Bad t1 + for (std::size_t i = 0; i < n; i++) { + evil_proof = proof; + evil_proof.t1[i].randomize(); + BOOST_CHECK(!(chaum.verify(mu, S, T, evil_proof))); + } + + // Bad t2 + evil_proof = proof; + evil_proof.t2.randomize(); + BOOST_CHECK(!(chaum.verify(mu, S, T, evil_proof))); + + // Bad t3 + evil_proof = proof; + evil_proof.t3.randomize(); + BOOST_CHECK(!(chaum.verify(mu, S, T, evil_proof))); +} + +BOOST_AUTO_TEST_SUITE_END() + +} diff --git a/src/libspark/test/coin_test.cpp b/src/libspark/test/coin_test.cpp new file mode 100644 index 0000000000..567069769b --- /dev/null +++ b/src/libspark/test/coin_test.cpp @@ -0,0 +1,107 @@ +#include "../coin.h" + +#include "../../test/test_bitcoin.h" +#include + +namespace spark { + +using namespace secp_primitives; + +BOOST_FIXTURE_TEST_SUITE(spark_coin_tests, BasicTestingSetup) + +BOOST_AUTO_TEST_CASE(mint_identify_recover) +{ + // Parameters + const Params* params; + params = Params::get_default(); + + const uint64_t i = 12345; + const uint64_t v = 86; + const std::string memo = "Spam and eggs"; + + // Generate keys + SpendKey spend_key(params); + FullViewKey full_view_key(spend_key); + IncomingViewKey incoming_view_key(full_view_key); + + // Generate address + Address address(incoming_view_key, i); + + // Generate coin + Scalar k; + k.randomize(); + Coin coin = Coin( + params, + COIN_TYPE_MINT, + k, + address, + v, + memo + ); + + // Identify coin + IdentifiedCoinData i_data = coin.identify(incoming_view_key); + BOOST_CHECK_EQUAL(i_data.i, i); + BOOST_CHECK_EQUAL_COLLECTIONS(i_data.d.begin(), i_data.d.end(), address.get_d().begin(), address.get_d().end()); + BOOST_CHECK_EQUAL(i_data.v, v); + BOOST_CHECK_EQUAL(i_data.k, k); + BOOST_CHECK_EQUAL(i_data.memo, memo); + + // Recover coin + RecoveredCoinData r_data = coin.recover(full_view_key, i_data); + BOOST_CHECK_EQUAL( + params->get_F()*(SparkUtils::hash_ser(k) + SparkUtils::hash_Q2(incoming_view_key.get_s1(), i) + full_view_key.get_s2()) + full_view_key.get_D(), + params->get_F()*r_data.s + full_view_key.get_D() + ); + BOOST_CHECK_EQUAL(r_data.T*r_data.s + full_view_key.get_D(), params->get_U()); +} + +BOOST_AUTO_TEST_CASE(spend_identify_recover) +{ + // Parameters + const Params* params; + params = Params::get_default(); + + const uint64_t i = 12345; + const uint64_t v = 86; + const std::string memo = "Spam and eggs"; + + // Generate keys + SpendKey spend_key(params); + FullViewKey full_view_key(spend_key); + IncomingViewKey incoming_view_key(full_view_key); + + // Generate address + Address address(incoming_view_key, i); + + // Generate coin + Scalar k; + k.randomize(); + Coin coin = Coin( + params, + COIN_TYPE_SPEND, + k, + address, + v, + memo + ); + + // Identify coin + IdentifiedCoinData i_data = coin.identify(incoming_view_key); + BOOST_CHECK_EQUAL(i_data.i, i); + BOOST_CHECK_EQUAL_COLLECTIONS(i_data.d.begin(), i_data.d.end(), address.get_d().begin(), address.get_d().end()); + BOOST_CHECK_EQUAL(i_data.v, v); + BOOST_CHECK_EQUAL(i_data.k, k); + BOOST_CHECK_EQUAL(i_data.memo, memo); + + // Recover coin + RecoveredCoinData r_data = coin.recover(full_view_key, i_data); + BOOST_CHECK_EQUAL( + params->get_F()*(SparkUtils::hash_ser(k) + SparkUtils::hash_Q2(incoming_view_key.get_s1(), i) + full_view_key.get_s2()) + full_view_key.get_D(), + params->get_F()*r_data.s + full_view_key.get_D() + ); + BOOST_CHECK_EQUAL(r_data.T*r_data.s + full_view_key.get_D(), params->get_U()); +} +BOOST_AUTO_TEST_SUITE_END() + +} diff --git a/src/libspark/test/encrypt_test.cpp b/src/libspark/test/encrypt_test.cpp new file mode 100644 index 0000000000..d0849b81c8 --- /dev/null +++ b/src/libspark/test/encrypt_test.cpp @@ -0,0 +1,52 @@ +#include "../util.h" +#include + +#include "../../test/test_bitcoin.h" +#include + +namespace spark { + +BOOST_FIXTURE_TEST_SUITE(spark_encrypt_tests, BasicTestingSetup) + +BOOST_AUTO_TEST_CASE(complete) +{ + // Key + std::string key_string = "Key prefix"; + std::vector key(key_string.begin(), key_string.end()); + key.resize(AES256_KEYSIZE); + + // Encrypt + uint64_t i = 12345; + std::vector d = SparkUtils::diversifier_encrypt(key, i); + + // Decrypt + uint64_t i_ = SparkUtils::diversifier_decrypt(key, d); + + BOOST_CHECK_EQUAL(i_, i); +} + +BOOST_AUTO_TEST_CASE(bad_key) +{ + // Key + std::string key_string = "Key prefix"; + std::vector key(key_string.begin(), key_string.end()); + key.resize(AES256_KEYSIZE); + + // Evil key + std::string evil_key_string = "Evil key prefix"; + std::vector evil_key(evil_key_string.begin(), evil_key_string.end()); + evil_key.resize(AES256_KEYSIZE); + + // Encrypt + uint64_t i = 12345; + std::vector d = SparkUtils::diversifier_encrypt(key, i); + + // Decrypt + uint64_t i_ = SparkUtils::diversifier_decrypt(evil_key, d); + + BOOST_CHECK_NE(i_, i); +} + +BOOST_AUTO_TEST_SUITE_END() + +} diff --git a/src/libspark/test/grootle_test.cpp b/src/libspark/test/grootle_test.cpp new file mode 100644 index 0000000000..7ac9bfbafd --- /dev/null +++ b/src/libspark/test/grootle_test.cpp @@ -0,0 +1,153 @@ +#include "../grootle.h" + +#include "../../test/test_bitcoin.h" +#include + +namespace spark { + +static std::vector random_group_vector(const std::size_t n) { + std::vector result; + result.resize(n); + for (std::size_t i = 0; i < n; i++) { + result[i].randomize(); + } + return result; +} + +BOOST_FIXTURE_TEST_SUITE(spark_grootle_tests, BasicTestingSetup) + +BOOST_AUTO_TEST_CASE(batch) +{ + // Parameters + const std::size_t n = 4; + const std::size_t m = 3; + const std::size_t N = (std::size_t) std::pow(n, m); // N = 64 + + // Generators + GroupElement H; + H.randomize(); + std::vector Gi = random_group_vector(n*m); + std::vector Hi = random_group_vector(n*m); + + // Commitments + std::size_t commit_size = 60; // require padding + std::vector S = random_group_vector(commit_size); + std::vector V = random_group_vector(commit_size); + + // Generate valid commitments to zero + std::vector indexes = { 0, 1, 3, 59 }; + std::vector sizes = { 60, 60, 59, 16 }; + std::vector S1, V1; + std::vector s, v; + for (std::size_t index : indexes) { + Scalar s_, v_; + s_.randomize(); + v_.randomize(); + s.emplace_back(s_); + v.emplace_back(v_); + + S1.emplace_back(S[index]); + V1.emplace_back(V[index]); + + S[index] += H*s_; + V[index] += H*v_; + } + + // Prepare proving system + Grootle grootle(H, Gi, Hi, n, m); + std::vector proofs; + + for (std::size_t i = 0; i < indexes.size(); i++) { + proofs.emplace_back(); + std::vector S_(S.begin() + commit_size - sizes[i], S.end()); + std::vector V_(V.begin() + commit_size - sizes[i], V.end()); + grootle.prove( + indexes[i] - (commit_size - sizes[i]), + s[i], + S_, + S1[i], + v[i], + V_, + V1[i], + proofs.back() + ); + + // Verify single proof + BOOST_CHECK(grootle.verify(S, S1[i], V, V1[i], sizes[i], proofs.back())); + } + + BOOST_CHECK(grootle.verify(S, S1, V, V1, sizes, proofs)); +} + +BOOST_AUTO_TEST_CASE(invalid_batch) +{ + // Parameters + const std::size_t n = 4; + const std::size_t m = 3; + const std::size_t N = (std::size_t) std::pow(n, m); // N = 64 + + // Generators + GroupElement H; + H.randomize(); + std::vector Gi = random_group_vector(n*m); + std::vector Hi = random_group_vector(n*m); + + // Commitments + std::size_t commit_size = 60; // require padding + std::vector S = random_group_vector(commit_size); + std::vector V = random_group_vector(commit_size); + + // Generate valid commitments to zero + std::vector indexes = { 0, 1, 3, 59 }; + std::vector sizes = { 60, 60, 59, 16 }; + std::vector S1, V1; + std::vector s, v; + for (std::size_t index : indexes) { + Scalar s_, v_; + s_.randomize(); + v_.randomize(); + s.emplace_back(s_); + v.emplace_back(v_); + + S1.emplace_back(S[index]); + V1.emplace_back(V[index]); + + S[index] += H*s_; + V[index] += H*v_; + } + + // Prepare proving system + Grootle grootle(H, Gi, Hi, n, m); + std::vector proofs; + + for (std::size_t i = 0; i < indexes.size(); i++) { + proofs.emplace_back(); + std::vector S_(S.begin() + commit_size - sizes[i], S.end()); + std::vector V_(V.begin() + commit_size - sizes[i], V.end()); + grootle.prove( + indexes[i] - (commit_size - sizes[i]), + s[i], + S_, + S1[i], + v[i], + V_, + V1[i], + proofs.back() + ); + } + + BOOST_CHECK(grootle.verify(S, S1, V, V1, sizes, proofs)); + + // Add an invalid proof + proofs.emplace_back(proofs.back()); + S1.emplace_back(S1.back()); + V1.emplace_back(V1.back()); + S1.back().randomize(); + sizes.emplace_back(sizes.back()); + + BOOST_CHECK(!grootle.verify(S, S1, V, V1, sizes, proofs)); +} + +BOOST_AUTO_TEST_SUITE_END() + +} \ No newline at end of file diff --git a/src/libspark/test/mint_transaction_test.cpp b/src/libspark/test/mint_transaction_test.cpp new file mode 100644 index 0000000000..911d4d9e59 --- /dev/null +++ b/src/libspark/test/mint_transaction_test.cpp @@ -0,0 +1,44 @@ +#include "../mint_transaction.h" + +#include "../../test/test_bitcoin.h" +#include + +namespace spark { + +using namespace secp_primitives; + +BOOST_FIXTURE_TEST_SUITE(spark_mint_transaction_tests, BasicTestingSetup) + +BOOST_AUTO_TEST_CASE(generate_verify) +{ + // Parameters + const Params* params; + params = Params::get_default(); + + const uint64_t i = 12345; + const uint64_t v = 86; + const std::string memo = "Spam and eggs"; + + // Generate keys + SpendKey spend_key(params); + FullViewKey full_view_key(spend_key); + IncomingViewKey incoming_view_key(full_view_key); + + // Generate address + Address address(incoming_view_key, i); + + // Generate mint transaction + MintTransaction t( + params, + address, + v, + memo + ); + + // Verify + BOOST_CHECK(t.verify()); +} + +BOOST_AUTO_TEST_SUITE_END() + +} diff --git a/src/libspark/test/schnorr_test.cpp b/src/libspark/test/schnorr_test.cpp new file mode 100644 index 0000000000..ea350ccaaf --- /dev/null +++ b/src/libspark/test/schnorr_test.cpp @@ -0,0 +1,85 @@ +#include "../schnorr.h" +#include "../../streams.h" +#include "../../version.h" + +#include "../../test/test_bitcoin.h" +#include + +namespace spark { + +BOOST_FIXTURE_TEST_SUITE(spark_schnorr_tests, BasicTestingSetup) + +BOOST_AUTO_TEST_CASE(serialization) +{ + GroupElement G; + G.randomize(); + + Scalar y; + y.randomize(); + GroupElement Y = G*y; + + SchnorrProof proof; + + Schnorr schnorr(G); + schnorr.prove(y, Y, proof); + + CDataStream serialized(SER_NETWORK, PROTOCOL_VERSION); + serialized << proof; + + SchnorrProof deserialized; + serialized >> deserialized; + + BOOST_CHECK(proof.c == deserialized.c); + BOOST_CHECK(proof.t == deserialized.t); +} + +BOOST_AUTO_TEST_CASE(completeness) +{ + GroupElement G; + G.randomize(); + + Scalar y; + y.randomize(); + GroupElement Y = G*y; + + SchnorrProof proof; + + Schnorr schnorr(G); + schnorr.prove(y, Y, proof); + + BOOST_CHECK(schnorr.verify(Y, proof)); +} + +BOOST_AUTO_TEST_CASE(bad_proofs) +{ + GroupElement G; + G.randomize(); + + Scalar y; + y.randomize(); + GroupElement Y = G*y; + + SchnorrProof proof; + + Schnorr schnorr(G); + schnorr.prove(y, Y, proof); + + // Bad Y + GroupElement evil_Y; + evil_Y.randomize(); + BOOST_CHECK(!(schnorr.verify(evil_Y, proof))); + + // Bad c + SchnorrProof evil_proof = proof; + evil_proof.c.randomize(); + BOOST_CHECK(!(schnorr.verify(Y, evil_proof))); + + // Bad t + evil_proof = proof; + evil_proof.t.randomize(); + BOOST_CHECK(!(schnorr.verify(Y, evil_proof))); +} + +BOOST_AUTO_TEST_SUITE_END() + +} diff --git a/src/libspark/test/spend_transaction_test.cpp b/src/libspark/test/spend_transaction_test.cpp new file mode 100644 index 0000000000..a8cda7fc20 --- /dev/null +++ b/src/libspark/test/spend_transaction_test.cpp @@ -0,0 +1,110 @@ +#include "../spend_transaction.h" + +#include "../../test/test_bitcoin.h" +#include + +namespace spark { + +using namespace secp_primitives; + +BOOST_FIXTURE_TEST_SUITE(spark_spend_transaction_tests, BasicTestingSetup) + +BOOST_AUTO_TEST_CASE(generate_verify) +{ + // Parameters + const Params* params; + params = Params::get_test(); + + const std::string memo = "Spam and eggs"; // arbitrary memo + + // Generate keys + SpendKey spend_key(params); + FullViewKey full_view_key(spend_key); + IncomingViewKey incoming_view_key(full_view_key); + + // Generate address + const uint64_t i = 12345; + Address address(incoming_view_key, i); + + // Mint some coins to the address + std::size_t N = (std::size_t) pow(params->get_n_grootle(), params->get_m_grootle()); + std::vector in_coins; + for (std::size_t i = 0; i < N; i++) { + Scalar k; + k.randomize(); + + uint64_t v = 12 + i; // arbitrary value + + in_coins.emplace_back(Coin( + params, + COIN_TYPE_MINT, + k, + address, + v, + memo + )); + } + + // Track values so we can set the fee to make the transaction balance + uint64_t f = 0; + + // Choose coins to spend, recover them, and prepare them for spending + std::vector spend_indices = { 1, 3, 5 }; + std::vector spend_coin_data; + const std::size_t w = spend_indices.size(); + for (std::size_t u = 0; u < w; u++) { + IdentifiedCoinData identified_coin_data = in_coins[spend_indices[u]].identify(incoming_view_key); + RecoveredCoinData recovered_coin_data = in_coins[spend_indices[u]].recover(full_view_key, identified_coin_data); + + spend_coin_data.emplace_back(); + spend_coin_data.back().index = spend_indices[u]; + spend_coin_data.back().k = identified_coin_data.k; + spend_coin_data.back().s = recovered_coin_data.s; + spend_coin_data.back().T = recovered_coin_data.T; + spend_coin_data.back().v = identified_coin_data.v; + + f -= identified_coin_data.v; + } + + // Generate new output coins and compute the fee + const std::size_t t = 2; + std::vector out_coin_data; + for (std::size_t j = 0; j < t; j++) { + out_coin_data.emplace_back(); + out_coin_data.back().address = address; + out_coin_data.back().v = 123 + j; // arbitrary value + out_coin_data.back().memo = memo; + + f += out_coin_data.back().v; + } + + // Assert the fee is correct + uint64_t fee_test = f; + for (std::size_t u = 0; u < w; u++) { + fee_test += spend_coin_data[u].v; + } + for (std::size_t j = 0; j < t; j++) { + fee_test -= out_coin_data[j].v; + } + if (fee_test != 0) { + throw std::runtime_error("Bad fee assertion"); + } + + // Generate spend transaction + SpendTransaction transaction( + params, + full_view_key, + spend_key, + in_coins, + spend_coin_data, + f, + out_coin_data + ); + + // Verify + BOOST_CHECK(transaction.verify()); +} + +BOOST_AUTO_TEST_SUITE_END() + +} diff --git a/src/libspark/test/transcript_test.cpp b/src/libspark/test/transcript_test.cpp new file mode 100644 index 0000000000..4ef9e1131d --- /dev/null +++ b/src/libspark/test/transcript_test.cpp @@ -0,0 +1,174 @@ +#include "../transcript.h" + +#include "../../test/test_bitcoin.h" +#include + +namespace spark { + +BOOST_FIXTURE_TEST_SUITE(spark_transcript_tests, BasicTestingSetup) + +BOOST_AUTO_TEST_CASE(init) +{ + // Identical domain separators + Transcript transcript_1("Spam"); + Transcript transcript_2("Spam"); + BOOST_CHECK_EQUAL(transcript_1.challenge("x"), transcript_2.challenge("x")); + + // Distinct domain separators + transcript_1 = Transcript("Spam"); + transcript_2 = Transcript("Eggs"); + BOOST_CHECK_NE(transcript_1.challenge("x"), transcript_2.challenge("x")); +} + +BOOST_AUTO_TEST_CASE(challenge_labels) +{ + Transcript transcript_1("Spam"); + Transcript transcript_2("Spam"); + + // Identical challenge labels + BOOST_CHECK_EQUAL(transcript_1.challenge("x"), transcript_2.challenge("x")); + + // Distinct challenge labels + BOOST_CHECK_NE(transcript_1.challenge("x"), transcript_2.challenge("y")); +} + +BOOST_AUTO_TEST_CASE(add_types) +{ + // Add all fixed types and assert distinct challenges + const std::string domain = "Spam"; + Transcript transcript(domain); + + Scalar scalar; + scalar.randomize(); + transcript.add("Scalar", scalar); + Scalar ch_1 = transcript.challenge("x"); + + GroupElement group; + group.randomize(); + transcript.add("Group", group); + Scalar ch_2 = transcript.challenge("x"); + BOOST_CHECK_NE(ch_1, ch_2); + + std::vector scalars; + for (std::size_t i = 0; i < 3; i++) { + scalar.randomize(); + scalars.emplace_back(scalar); + } + Scalar ch_3 = transcript.challenge("x"); + BOOST_CHECK_NE(ch_2, ch_3); + + std::vector groups; + for (std::size_t i = 0; i < 3; i++) { + group.randomize(); + groups.emplace_back(group); + } + Scalar ch_4 = transcript.challenge("x"); + BOOST_CHECK_NE(ch_3, ch_4); + + const std::string data = "Arbitrary string"; + const std::vector data_char(data.begin(), data.end()); + transcript.add("Data", data_char); + Scalar ch_5 = transcript.challenge("x"); + BOOST_CHECK_NE(ch_4, ch_5); +} + +BOOST_AUTO_TEST_CASE(repeated_challenge) +{ + // Repeated challenges must be distinct, even with the same label + Transcript transcript("Eggs"); + + Scalar ch_1 = transcript.challenge("x"); + Scalar ch_2 = transcript.challenge("x"); + + BOOST_CHECK_NE(ch_1, ch_2); +} + +BOOST_AUTO_TEST_CASE(repeated_challenge_ordering) +{ + // Repeated challenges must respect ordering + Transcript prover("Spam"); + Transcript verifier("Spam"); + + Scalar prover_x = prover.challenge("x"); + Scalar prover_y = prover.challenge("y"); + + // Oh no, we mixed up the order + Scalar verifier_y = verifier.challenge("y"); + Scalar verifier_x = verifier.challenge("x"); + + BOOST_CHECK_NE(prover_x, verifier_x); + BOOST_CHECK_NE(prover_y, verifier_y); +} + +BOOST_AUTO_TEST_CASE(identical_transcripts) +{ + // Ensure that identical transcripts yield identical challenges + Transcript prover("Beer"); + Transcript verifier("Beer"); + + Scalar scalar; + scalar.randomize(); + GroupElement group; + group.randomize(); + + prover.add("Scalar", scalar); + verifier.add("Scalar", scalar); + prover.add("Group", group); + verifier.add("Group", group); + + BOOST_CHECK_EQUAL(prover.challenge("x"), verifier.challenge("x")); +} + +BOOST_AUTO_TEST_CASE(distinct_values) +{ + // Ensure that distinct transcript values yield distinct challenges + Transcript prover("Soda"); + Transcript verifier("Soda"); + + Scalar prover_scalar; + prover_scalar.randomize(); + Scalar verifier_scalar; + verifier_scalar.randomize(); + + prover.add("Scalar", prover_scalar); + verifier.add("Scalar", verifier_scalar); + + BOOST_CHECK_NE(prover.challenge("x"), verifier.challenge("x")); +} + +BOOST_AUTO_TEST_CASE(distinct_labels) +{ + // Ensure that distinct transcript labels yield distinct challenges + Transcript prover("Soda"); + Transcript verifier("Soda"); + + Scalar scalar; + scalar.randomize(); + + prover.add("Prover scalar", scalar); + verifier.add("Verifier scalar", scalar); + + BOOST_CHECK_NE(prover.challenge("x"), verifier.challenge("y")); +} + +BOOST_AUTO_TEST_CASE(converging) +{ + // Transcripts with distinct initial states but common post-challenge elements + Transcript transcript_1("Spam"); + Transcript transcript_2("Eggs"); + + Scalar ch_1 = transcript_1.challenge("x"); + Scalar ch_2 = transcript_1.challenge("x"); + + // Add a common element and assert the states still differ + Scalar scalar; + scalar.randomize(); + transcript_1.add("Scalar", scalar); + transcript_2.add("Scalar", scalar); + + BOOST_CHECK_NE(transcript_1.challenge("x"), transcript_2.challenge("x")); +} + +BOOST_AUTO_TEST_SUITE_END() + +} diff --git a/src/libspark/transcript.cpp b/src/libspark/transcript.cpp new file mode 100644 index 0000000000..238e2d3ae9 --- /dev/null +++ b/src/libspark/transcript.cpp @@ -0,0 +1,177 @@ +#include "transcript.h" + +namespace spark { + +using namespace secp_primitives; + +// Flags for transcript operations +const unsigned char FLAG_DOMAIN = 0; +const unsigned char FLAG_DATA = 1; +const unsigned char FLAG_VECTOR = 2; +const unsigned char FLAG_CHALLENGE = 3; + +// Initialize a transcript with a domain separator +Transcript::Transcript(const std::string domain) { + // Prepare the state + this->ctx = EVP_MD_CTX_new(); + EVP_DigestInit_ex(this->ctx, EVP_blake2b512(), NULL); + + // Write the protocol and mode information + std::vector protocol(LABEL_PROTOCOL.begin(), LABEL_PROTOCOL.end()); + EVP_DigestUpdate(this->ctx, protocol.data(), protocol.size()); + EVP_DigestUpdate(this->ctx, &HASH_MODE_TRANSCRIPT, sizeof(HASH_MODE_TRANSCRIPT)); + + // Domain separator + include_flag(FLAG_DOMAIN); + include_label(domain); +} + +Transcript::~Transcript() { + EVP_MD_CTX_free(this->ctx); +} + +Transcript& Transcript::operator=(const Transcript& t) { + if (this == &t) { + return *this; + } + + EVP_MD_CTX_copy_ex(this->ctx, t.ctx); + + return *this; +} + +// Add a group element +void Transcript::add(const std::string label, const GroupElement& group_element) { + std::vector data; + data.resize(GroupElement::serialize_size); + group_element.serialize(data.data()); + + include_flag(FLAG_DATA); + include_label(label); + include_data(data); +} + +// Add a vector of group elements +void Transcript::add(const std::string label, const std::vector& group_elements) { + include_flag(FLAG_VECTOR); + size(group_elements.size()); + include_label(label); + for (std::size_t i = 0; i < group_elements.size(); i++) { + std::vector data; + data.resize(GroupElement::serialize_size); + group_elements[i].serialize(data.data()); + include_data(data); + } +} + +// Add a scalar +void Transcript::add(const std::string label, const Scalar& scalar) { + std::vector data; + data.resize(SCALAR_ENCODING); + scalar.serialize(data.data()); + + include_flag(FLAG_DATA); + include_label(label); + include_data(data); +} + +// Add a vector of scalars +void Transcript::add(const std::string label, const std::vector& scalars) { + include_flag(FLAG_VECTOR); + size(scalars.size()); + include_label(label); + for (std::size_t i = 0; i < scalars.size(); i++) { + std::vector data; + data.resize(SCALAR_ENCODING); + scalars[i].serialize(data.data()); + include_data(data); + } +} + +// Add arbitrary data +void Transcript::add(const std::string label, const std::vector& data) { + include_flag(FLAG_DATA); + include_label(label); + include_data(data); +} + +// Produce a challenge +Scalar Transcript::challenge(const std::string label) { + // Ensure we can properly populate a scalar + if (EVP_MD_size(EVP_blake2b512()) < SCALAR_ENCODING) { + throw std::runtime_error("Bad hash size!"); + } + + std::vector hash; + hash.resize(EVP_MD_size(EVP_blake2b512())); + unsigned char counter = 0; + + EVP_MD_CTX* state_counter; + state_counter = EVP_MD_CTX_new(); + EVP_DigestInit_ex(state_counter, EVP_blake2b512(), NULL); + + EVP_MD_CTX* state_finalize; + state_finalize = EVP_MD_CTX_new(); + EVP_DigestInit_ex(state_finalize, EVP_blake2b512(), NULL); + + include_flag(FLAG_CHALLENGE); + include_label(label); + + while (1) { + // Prepare temporary state for counter testing + EVP_MD_CTX_copy_ex(state_counter, this->ctx); + + // Embed the counter + EVP_DigestUpdate(state_counter, &counter, sizeof(counter)); + + // Finalize the hash with a temporary state + EVP_MD_CTX_copy_ex(state_finalize, state_counter); + unsigned int TEMP; // We already know the digest length! + EVP_DigestFinal_ex(state_finalize, hash.data(), &TEMP); + + // Check for scalar validity + Scalar candidate; + try { + candidate.deserialize(hash.data()); + EVP_MD_CTX_copy_ex(this->ctx, state_counter); + + EVP_MD_CTX_free(state_counter); + EVP_MD_CTX_free(state_finalize); + + return candidate; + } catch (...) { + counter++; + } + } +} + +// Encode and include a size +void Transcript::size(const std::size_t size_) { + Scalar size_scalar(size_); + std::vector size_data; + size_data.resize(SCALAR_ENCODING); + size_scalar.serialize(size_data.data()); + EVP_DigestUpdate(this->ctx, size_data.data(), size_data.size()); +} + +// Include a flag +void Transcript::include_flag(const unsigned char flag) { + EVP_DigestUpdate(this->ctx, &flag, sizeof(flag)); +} + +// Encode and include a label +void Transcript::include_label(const std::string label) { + std::vector bytes(label.begin(), label.end()); + include_data(bytes); +} + +// Encode and include data +void Transcript::include_data(const std::vector& data) { + // Include size + size(data.size()); + + // Include data + EVP_DigestUpdate(this->ctx, data.data(), data.size()); +} + +} diff --git a/src/libspark/transcript.h b/src/libspark/transcript.h new file mode 100644 index 0000000000..eef2f9f59b --- /dev/null +++ b/src/libspark/transcript.h @@ -0,0 +1,32 @@ +#ifndef FIRO_SPARK_TRANSCRIPT_H +#define FIRO_SPARK_TRANSCRIPT_H +#include +#include "util.h" + +namespace spark { + +using namespace secp_primitives; + +class Transcript { +public: + Transcript(const std::string); + Transcript& operator=(const Transcript&); + ~Transcript(); + void add(const std::string, const Scalar&); + void add(const std::string, const std::vector&); + void add(const std::string, const GroupElement&); + void add(const std::string, const std::vector&); + void add(const std::string, const std::vector&); + Scalar challenge(const std::string); + +private: + void size(const std::size_t size_); + void include_flag(const unsigned char); + void include_label(const std::string); + void include_data(const std::vector&); + EVP_MD_CTX* ctx; +}; + +} + +#endif diff --git a/src/libspark/util.cpp b/src/libspark/util.cpp new file mode 100644 index 0000000000..1f58f97e69 --- /dev/null +++ b/src/libspark/util.cpp @@ -0,0 +1,232 @@ +#include "util.h" + +namespace spark { + +using namespace secp_primitives; + +// Encrypt a diversifier using AES-256 +std::vector SparkUtils::diversifier_encrypt(const std::vector& key, const uint64_t i) { + // Serialize the diversifier + CDataStream i_stream(SER_NETWORK, PROTOCOL_VERSION); + i_stream << i; + + // Assert proper sizes + if (key.size() != AES256_KEYSIZE) { + throw std::invalid_argument("Bad diversifier encryption key size"); + } + + // Encrypt using padded AES-256 (CBC) using a zero IV + std::vector ciphertext; + ciphertext.resize(AES_BLOCKSIZE); + std::vector iv; + iv.resize(AES_BLOCKSIZE); + + AES256CBCEncrypt aes(key.data(), iv.data(), true); + aes.Encrypt(reinterpret_cast(i_stream.data()), i_stream.size(), ciphertext.data()); + + return ciphertext; +} + +// Decrypt a diversifier using AES-256 +uint64_t SparkUtils::diversifier_decrypt(const std::vector& key, const std::vector& d) { + // Assert proper sizes + if (key.size() != AES256_KEYSIZE) { + throw std::invalid_argument("Bad diversifier decryption key size"); + } + + // Decrypt using padded AES-256 (CBC) using a zero IV + CDataStream i_stream(SER_NETWORK, PROTOCOL_VERSION); + i_stream.resize(sizeof(uint64_t)); + + std::vector iv; + iv.resize(AES_BLOCKSIZE); + + AES256CBCDecrypt aes(key.data(), iv.data(), true); + aes.Decrypt(d.data(), d.size(), reinterpret_cast(i_stream.data())); + + // Deserialize the diversifier + uint64_t i; + i_stream >> i; + + return i; +} + +// Produce a uniformly-sampled group element from a label +GroupElement SparkUtils::hash_generator(const std::string label) { + const int GROUP_ENCODING = 34; + const unsigned char ZERO = 0; + + // Ensure we can properly populate a + if (EVP_MD_size(EVP_blake2b512()) < GROUP_ENCODING) { + throw std::runtime_error("Bad hash size!"); + } + + EVP_MD_CTX* ctx; + ctx = EVP_MD_CTX_new(); + EVP_DigestInit_ex(ctx, EVP_blake2b512(), NULL); + + // Write the protocol and mode + std::vector protocol(LABEL_PROTOCOL.begin(), LABEL_PROTOCOL.end()); + EVP_DigestUpdate(ctx, protocol.data(), protocol.size()); + EVP_DigestUpdate(ctx, &HASH_MODE_GROUP_GENERATOR, sizeof(HASH_MODE_GROUP_GENERATOR)); + + // Write the label + std::vector bytes(label.begin(), label.end()); + EVP_DigestUpdate(ctx, bytes.data(), bytes.size()); + + std::vector hash; + hash.resize(EVP_MD_size(EVP_blake2b512())); + unsigned char counter = 0; + + EVP_MD_CTX* state_counter; + state_counter = EVP_MD_CTX_new(); + EVP_DigestInit_ex(state_counter, EVP_blake2b512(), NULL); + + EVP_MD_CTX* state_finalize; + state_finalize = EVP_MD_CTX_new(); + EVP_DigestInit_ex(state_finalize, EVP_blake2b512(), NULL); + + // Finalize the hash + while (1) { + // Prepare temporary state for counter testing + EVP_MD_CTX_copy_ex(state_counter, ctx); + + // Embed the counter + EVP_DigestUpdate(state_counter, &counter, sizeof(counter)); + + // Finalize the hash with a temporary state + EVP_MD_CTX_copy_ex(state_finalize, state_counter); + unsigned int TEMP; // We already know the digest length! + EVP_DigestFinal_ex(state_finalize, hash.data(), &TEMP); + + // Assemble the serialized input: + // bytes 0..31: x coordinate + // byte 32: even/odd + // byte 33: zero (this point is not infinity) + unsigned char candidate_bytes[GROUP_ENCODING]; + memcpy(candidate_bytes, hash.data(), 33); + memcpy(candidate_bytes + 33, &ZERO, 1); + GroupElement candidate; + try { + candidate.deserialize(candidate_bytes); + + // Deserialization can succeed even with an invalid result + if (!candidate.isMember()) { + counter++; + continue; + } + + EVP_MD_CTX_free(ctx); + EVP_MD_CTX_free(state_counter); + EVP_MD_CTX_free(state_finalize); + + return candidate; + } catch (...) { + counter++; + } + } +} + +// Derive an AES key for diversifier encryption/decryption +std::vector SparkUtils::kdf_diversifier(const Scalar& s1) { + KDF kdf(LABEL_KDF_DIVERSIFIER); + + CDataStream stream(SER_NETWORK, PROTOCOL_VERSION); + stream << s1; + kdf.include(stream); + + return kdf.finalize(AES256_KEYSIZE); +} + +// Derive a ChaCha20 key for AEAD operations +std::vector SparkUtils::kdf_aead(const GroupElement& K_der) { + KDF kdf(LABEL_KDF_AEAD); + + CDataStream stream(SER_NETWORK, PROTOCOL_VERSION); + stream << K_der; + kdf.include(stream); + + return kdf.finalize(AEAD_KEY_SIZE); +} + +// Hash-to-group function H_div +GroupElement SparkUtils::hash_div(const std::vector& d) { + Hash hash(LABEL_HASH_DIV); + + CDataStream stream(SER_NETWORK, PROTOCOL_VERSION); + stream << d; + hash.include(stream); + + return hash.finalize_group(); +} + +// Hash-to-scalar function H_Q2 +Scalar SparkUtils::hash_Q2(const Scalar& s1, const Scalar& i) { + Hash hash(LABEL_HASH_Q2); + + CDataStream stream(SER_NETWORK, PROTOCOL_VERSION); + stream << s1; + stream << i; + hash.include(stream); + + return hash.finalize_scalar(); +} + +// Hash-to-scalar function H_k +Scalar SparkUtils::hash_k(const Scalar& k) { + Hash hash(LABEL_HASH_K); + + CDataStream stream(SER_NETWORK, PROTOCOL_VERSION); + stream << k; + hash.include(stream); + + return hash.finalize_scalar(); +} + +// Hash-to-scalar function H_ser +Scalar SparkUtils::hash_ser(const Scalar& k) { + Hash hash(LABEL_HASH_SER); + + CDataStream stream(SER_NETWORK, PROTOCOL_VERSION); + stream << k; + hash.include(stream); + + return hash.finalize_scalar(); +} + +// Hash-to-scalar function H_val +Scalar SparkUtils::hash_val(const Scalar& k) { + Hash hash(LABEL_HASH_VAL); + + CDataStream stream(SER_NETWORK, PROTOCOL_VERSION); + stream << k; + hash.include(stream); + + return hash.finalize_scalar(); +} + +// Hash-to-scalar function H_ser1 +Scalar SparkUtils::hash_ser1(const Scalar& s, const GroupElement& D) { + Hash hash(LABEL_HASH_SER1); + + CDataStream stream(SER_NETWORK, PROTOCOL_VERSION); + stream << s; + stream << D; + hash.include(stream); + + return hash.finalize_scalar(); +} + +// Hash-to-scalar function H_val1 +Scalar SparkUtils::hash_val1(const Scalar& s, const GroupElement& D) { + Hash hash(LABEL_HASH_VAL1); + + CDataStream stream(SER_NETWORK, PROTOCOL_VERSION); + stream << s; + stream << D; + hash.include(stream); + + return hash.finalize_scalar(); +} + +} diff --git a/src/libspark/util.h b/src/libspark/util.h new file mode 100644 index 0000000000..ff37e8de56 --- /dev/null +++ b/src/libspark/util.h @@ -0,0 +1,84 @@ +#ifndef FIRO_SPARK_UTIL_H +#define FIRO_SPARK_UTIL_H +#include +#include +#include "../../crypto/aes.h" +#include "../streams.h" +#include "../version.h" +#include "../util.h" +#include "kdf.h" +#include "hash.h" +#include "grootle_proof.h" +#include "schnorr_proof.h" + +namespace spark { + +using namespace secp_primitives; + +// Useful serialization constant +const std::size_t SCALAR_ENCODING = 32; + +// Base protocol separator +const std::string LABEL_PROTOCOL = "SPARK"; + +// All hash operations have a mode flag to separate their use cases +const unsigned char HASH_MODE_TRANSCRIPT = 0; // a Fiat-Shamir transcript +const unsigned char HASH_MODE_GROUP_GENERATOR = 1; // a prime-order group generator derived from a label +const unsigned char HASH_MODE_FUNCTION = 2; // a hash function derived from a label +const unsigned char HASH_MODE_KDF = 3; // a key derivation function derived from a label + +// Generator labels +const std::string LABEL_GENERATOR_F = "F"; +const std::string LABEL_GENERATOR_G = "G"; +const std::string LABEL_GENERATOR_H = "H"; +const std::string LABEL_GENERATOR_U = "U"; +const std::string LABEL_GENERATOR_G_RANGE = "G_RANGE"; +const std::string LABEL_GENERATOR_H_RANGE = "H_RANGE"; +const std::string LABEL_GENERATOR_G_GROOTLE = "G_GROOTLE"; +const std::string LABEL_GENERATOR_H_GROOTLE = "H_GROOTLE"; + +// Hash function labels +const std::string LABEL_HASH_DIV = "DIV"; +const std::string LABEL_HASH_Q2 = "Q2"; +const std::string LABEL_HASH_K = "K"; +const std::string LABEL_HASH_SER = "SER"; +const std::string LABEL_HASH_VAL = "VAL"; +const std::string LABEL_HASH_SER1 = "SER1"; +const std::string LABEL_HASH_VAL1 = "VAL1"; +const std::string LABEL_HASH_BIND = "BIND"; + +// KDF labels +const std::string LABEL_KDF_DIVERSIFIER = "DIVERSIFIER"; +const std::string LABEL_KDF_AEAD = "AEAD"; + +// AEAD constants +const int AEAD_IV_SIZE = 12; // byte length of the IV +const int AEAD_KEY_SIZE = 32; // byte length of the key +const int AEAD_TAG_SIZE = 16; // byte length of the tag + +class SparkUtils { +public: + // Protocol-level hash functions + static GroupElement hash_generator(const std::string label); + + // Hash functions + static GroupElement hash_div(const std::vector& d); + static Scalar hash_Q2(const Scalar& s1, const Scalar& i); + static Scalar hash_k(const Scalar& k); + static Scalar hash_ser(const Scalar& k); + static Scalar hash_val(const Scalar& k); + static Scalar hash_ser1(const Scalar& s, const GroupElement& D); + static Scalar hash_val1(const Scalar& s, const GroupElement& D); + + // Key derivation functions + static std::vector kdf_diversifier(const Scalar& s1); + static std::vector kdf_aead(const GroupElement& K_der); + + // Diversifier encryption/decryption + static std::vector diversifier_encrypt(const std::vector& key, const uint64_t i); + static uint64_t diversifier_decrypt(const std::vector& key, const std::vector& d); +}; + +} + +#endif