Skip to content

Commit

Permalink
fix up the encoding, make this look closer to what the basic filter w…
Browse files Browse the repository at this point in the history
…as doing
  • Loading branch information
josibake committed Jan 29, 2024
1 parent a6023c8 commit 81fcdbf
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 46 deletions.
53 changes: 46 additions & 7 deletions src/blockfilter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Distributed under the MIT software license, see the accompanying
// file COPYING or http://www.opensource.org/licenses/mit-license.php.

#include <cstddef>
#include <mutex>
#include <set>

Expand Down Expand Up @@ -146,6 +147,45 @@ bool GCSFilter::MatchAny(const ElementSet& elements) const
return MatchInternal(queries.data(), queries.size());
}

InputsFilter::InputsFilter()
: m_N(0), m_encoded{0}
{}

InputsFilter::InputsFilter(const ElementSet& elements)
{
size_t N = elements.size();
m_N = static_cast<uint32_t>(N);
if (m_N != N) {
throw std::invalid_argument("N must be <2^32");
}
VectorWriter stream{m_encoded, 0};
WriteCompactSize(stream, m_N * 33);
if (elements.empty()) {
return;
}
for (const Element& e: elements) {
stream << e;
}
}

InputsFilter::InputsFilter(std::vector<unsigned char> encoded_filter, bool skip_decode_check)
: m_encoded(std::move(encoded_filter))
{
SpanReader stream{m_encoded};

uint64_t N = ReadCompactSize(stream);
m_N = static_cast<uint32_t>(N) / 33;
if (m_N != (N / 33)) {
throw std::ios_base::failure("N must be <2^32");
}
if (skip_decode_check) return;
// Verify that the encoded filter contains exactly N elements. If it has too much or too little
// data, a std::ios_base::failure exception will be raised.
if (stream.size() != m_N * 33) {
throw std::ios_base::failure("encoded_filter contains excess data");
}
}

const std::string& BlockFilterTypeName(BlockFilterType filter_type)
{
static std::string unknown_retval;
Expand Down Expand Up @@ -208,12 +248,12 @@ static GCSFilter::ElementSet BasicFilterElements(const CBlock& block,
return elements;
}

static InputsFilter SilentPaymentFilterElements(const CBlock& block,
static InputsFilter::ElementSet SilentPaymentFilterElements(const CBlock& block,
const CBlockUndo& block_undo)
{
if (block_undo.vtxundo.empty()) return InputsFilter();
if (block.vtx.size() == 1) return InputsFilter();
InputsFilter elements;
InputsFilter::ElementSet elements;
if (block_undo.vtxundo.empty()) return elements;
if (block.vtx.size() == 1) return elements;
assert(block.vtx.size() - 1 == block_undo.vtxundo.size());
for (uint32_t i = 0; i < block.vtx.size(); ++i) {
const CTransactionRef& tx = block.vtx.at(i);
Expand All @@ -237,9 +277,8 @@ static InputsFilter SilentPaymentFilterElements(const CBlock& block,
inputs_hash.Set(tweak_data->first.begin(), tweak_data->first.end(), true);
CPubKey input_pubkeys_sum{tweak_data->second};
CPubKey final{inputs_hash.UnhashedECDH(input_pubkeys_sum)};
elements.addElement(final.data());
elements.emplace(*final.begin(), *final.end());
}
elements.Encode();
return elements;
}

Expand All @@ -257,7 +296,7 @@ BlockFilter::BlockFilter(BlockFilterType filter_type, const uint256& block_hash,
m_filter = GCSFilter(params, std::move(filter), skip_decode_check);
break;
case BlockFilterType::SILENT_PAYMENTS:
m_filter = InputsFilter(std::move(filter));
m_filter = InputsFilter(std::move(filter), skip_decode_check);
break;
case BlockFilterType::INVALID:
throw std::invalid_argument("unknown filter_type");
Expand Down
55 changes: 16 additions & 39 deletions src/blockfilter.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,50 +90,27 @@ class GCSFilter

class InputsFilter {
public:
InputsFilter() = default;
InputsFilter(std::vector<unsigned char> filter)
: m_encoded(filter)
{}
// Element structure
struct Element {
char data[33];

Element(const unsigned char* d) {
std::memcpy(data, d, 33);
}
Element();
};
typedef std::vector<unsigned char> Element;
typedef std::unordered_set<Element, ByteVectorHash> ElementSet;

// Adds an element to the filter
void addElement(const Element& element) {
elements.push_back(element);
}
private:
uint32_t m_N; //!< Number of elements in the filter
std::vector<unsigned char> m_encoded; // Encoded block of bytes

// Encodes the elements into m_encoded
void Encode() {
m_encoded.clear();
for (const auto& element : elements) {
m_encoded.insert(m_encoded.end(), element.data, element.data + 33);
}
}
public:
/** Constructs an empty filter. */
explicit InputsFilter();

// Returns the encoded data
const std::vector<unsigned char>& GetEncoded() const LIFETIMEBOUND { return m_encoded; }
/** Reconstructs an already-created filter from an encoding. */
InputsFilter(std::vector<unsigned char> encoded_filter, bool skip_decode_check);

private:
std::vector<Element> elements; // Stores the elements
std::vector<unsigned char> m_encoded; // Encoded block of bytes
/** Builds a new filter from the params and set of elements. */
InputsFilter(const ElementSet& elements);

// decode into elements from a encoded InputsFilter
void decode() {
elements.clear();
const size_t elementSize = 33;
for (size_t i = 0; i < m_encoded.size(); i += elementSize) {
Element element;
std::memcpy(element.data, &m_encoded[i], elementSize);
elements.push_back(element);
}
}
uint32_t GetN() const { return m_N; }

/** Returns the encoded data */
const std::vector<unsigned char>& GetEncoded() const LIFETIMEBOUND { return m_encoded; }
};

using Filter = std::variant<GCSFilter, InputsFilter>;
Expand Down

0 comments on commit 81fcdbf

Please sign in to comment.