Skip to content

Commit

Permalink
Implement peer handling for bloom filters
Browse files Browse the repository at this point in the history
  • Loading branch information
tdittr committed Nov 4, 2023
1 parent 47140a9 commit fd6586c
Show file tree
Hide file tree
Showing 8 changed files with 268 additions and 24 deletions.
5 changes: 5 additions & 0 deletions ntp-proto/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ mod exports {
KeyExchangeClient, KeyExchangeError, KeyExchangeResult, KeyExchangeServer, NtsRecord,
NtsRecordDecoder, WriteError,
};

#[cfg(feature = "ntpv5")]
pub mod v5 {
pub use crate::packet::v5::server_reference_id::{BloomFilter, ServerId};
}
}

#[cfg(feature = "__internal-api")]
Expand Down
7 changes: 7 additions & 0 deletions ntp-proto/src/packet/extension_fields.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ use std::{

use crate::keyset::DecodedServerCookie;

#[cfg(feature = "ntpv5")]
use crate::packet::v5::extension_fields::{ReferenceIdRequest, ReferenceIdResponse};

use super::{crypto::EncryptResult, error::ParsingError, Cipher, CipherProvider, Mac};

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
Expand Down Expand Up @@ -550,6 +553,10 @@ impl<'a> ExtensionField<'a> {
TypeId::DraftIdentification => {
EF::decode_draft_identification(message, extension_header_version)
}
#[cfg(feature = "ntpv5")]
TypeId::ReferenceIdRequest => Ok(ReferenceIdRequest::decode(message)?.into()),
#[cfg(feature = "ntpv5")]
TypeId::ReferenceIdResponse => Ok(ReferenceIdResponse::decode(message).into()),
type_id => EF::decode_unknown(type_id.to_type_id(), message),
}
}
Expand Down
27 changes: 23 additions & 4 deletions ntp-proto/src/packet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ mod extension_fields;
mod mac;

#[cfg(feature = "ntpv5")]
mod v5;
pub mod v5;

pub use crypto::{
AesSivCmac256, AesSivCmac512, Cipher, CipherHolder, CipherProvider, DecryptError,
Expand Down Expand Up @@ -118,15 +118,15 @@ pub struct NtpPacket<'a> {
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum NtpHeader {
pub enum NtpHeader {
V3(NtpHeaderV3V4),
V4(NtpHeaderV3V4),
#[cfg(feature = "ntpv5")]
V5(v5::NtpHeaderV5),
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
struct NtpHeaderV3V4 {
pub struct NtpHeaderV3V4 {
leap: NtpLeapIndicator,
mode: NtpAssociationMode,
stratum: u8,
Expand Down Expand Up @@ -652,7 +652,14 @@ impl<'a> NtpPacket<'a> {
.untrusted
.into_iter()
.chain(input.efdata.authenticated)
.filter(|ef| matches!(ef, ExtensionField::UniqueIdentifier(_)))
.filter_map(|ef| match ef {
uid @ ExtensionField::UniqueIdentifier(_) => Some(uid),
ExtensionField::ReferenceIdRequest(req) => {
let response = req.to_response(&system.bloom_filter)?;
Some(ExtensionField::ReferenceIdResponse(response).into_owned())
}
_ => None,
})
.chain(std::iter::once(ExtensionField::DraftIdentification(
Cow::Borrowed(v5::DRAFT_VERSION),
)))
Expand Down Expand Up @@ -871,6 +878,10 @@ impl<'a> NtpPacket<'a> {
}
}

pub fn header(&self) -> NtpHeader {
self.header
}

pub fn leap(&self) -> NtpLeapIndicator {
match self.header {
NtpHeader::V3(header) => header.leap,
Expand Down Expand Up @@ -1039,6 +1050,10 @@ impl<'a> NtpPacket<'a> {
}
}
}

pub fn untrusted_extension_fields(&self) -> impl Iterator<Item = &ExtensionField> {
self.efdata.untrusted.iter()
}
}

// Returns whether all uid extension fields found match the given uid, or
Expand Down Expand Up @@ -1167,6 +1182,10 @@ impl<'a> NtpPacket<'a> {
NtpHeader::V5(ref mut header) => header.root_dispersion = root_dispersion,
}
}

pub fn push_untrusted(&mut self, ef: ExtensionField<'static>) {
self.efdata.untrusted.push(ef);
}
}

impl<'a> Default for NtpPacket<'a> {
Expand Down
45 changes: 41 additions & 4 deletions ntp-proto/src/packet/v5/extension_fields.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use crate::packet::error::ParsingError;
use crate::packet::v5::server_reference_id::BloomFilter;
use crate::ExtensionField;
use std::borrow::Cow;
use std::convert::Infallible;
use std::io::Write;

#[derive(Debug, Copy, Clone, Eq, PartialEq)]
Expand Down Expand Up @@ -108,24 +111,40 @@ impl ReferenceIdRequest {
}

pub fn serialize(&self, mut writer: impl Write) -> std::io::Result<()> {
let payload_len = self.payload_len;
let ef_len: u16 = payload_len + 4;

writer.write_all(&Type::ReferenceIdRequest.to_bits().to_be_bytes())?;
writer.write_all(&ef_len.to_be_bytes())?;
writer.write_all(&self.offset.to_be_bytes())?;
writer.write_all(&[0; 2])?;

let words = self.payload_len / 4;
assert_eq!(self.payload_len % 4, 0);
let words = payload_len / 4;
assert_eq!(payload_len % 4, 0);

for _ in 1..words {
writer.write_all(&[0; 4])?;
}

Ok(())
}

pub(crate) fn offset(&self) -> u16 {
pub fn decode(msg: &[u8]) -> Result<Self, ParsingError<Infallible>> {
let payload_len =
u16::try_from(msg.len()).expect("NTP fields can not be longer than u16::MAX");
let offset_bytes: [u8; 2] = msg[0..2].try_into().unwrap();

Ok(Self {
payload_len,
offset: u16::from_be_bytes(offset_bytes),
})
}

pub fn offset(&self) -> u16 {
self.offset
}

pub(crate) fn payload_len(&self) -> u16 {
pub fn payload_len(&self) -> u16 {
self.payload_len
}
}
Expand Down Expand Up @@ -168,11 +187,29 @@ impl<'a> ReferenceIdResponse<'a> {
Ok(())
}

pub fn decode(bytes: &'a [u8]) -> Self {
Self {
bytes: Cow::Borrowed(bytes),
}
}

pub fn bytes(&self) -> &[u8] {
&*self.bytes
}
}

impl From<ReferenceIdRequest> for ExtensionField<'static> {
fn from(value: ReferenceIdRequest) -> Self {
Self::ReferenceIdRequest(value)
}
}

impl<'a> From<ReferenceIdResponse<'a>> for ExtensionField<'a> {
fn from(value: ReferenceIdResponse<'a>) -> Self {
Self::ReferenceIdResponse(value)
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
57 changes: 47 additions & 10 deletions ntp-proto/src/packet/v5/server_reference_id.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use crate::packet::v5::extension_fields::{ReferenceIdRequest, ReferenceIdResponse};
use crate::packet::v5::NtpClientCookie;
use rand::distributions::{Distribution, Standard};
use rand::Rng;
use rand::{thread_rng, Rng};
use std::array::from_fn;
use std::fmt::{Debug, Formatter};

#[derive(Copy, Clone, Debug)]
struct U12(u16);
Expand Down Expand Up @@ -41,7 +42,7 @@ impl TryFrom<u16> for U12 {
}
}

#[derive(Debug)]
#[derive(Debug, Copy, Clone)]
pub struct ServerId([U12; 10]);

impl ServerId {
Expand All @@ -54,7 +55,13 @@ impl ServerId {
}
}

#[derive(Clone, Eq, PartialEq, Debug)]
impl Default for ServerId {
fn default() -> Self {
Self::new(&mut thread_rng())
}
}

#[derive(Copy, Clone, Eq, PartialEq)]
pub struct BloomFilter([u8; Self::BYTES]);
impl BloomFilter {
pub const BYTES: usize = 512;
Expand Down Expand Up @@ -110,6 +117,25 @@ impl<'a> FromIterator<&'a BloomFilter> for BloomFilter {
}
}

impl Default for BloomFilter {
fn default() -> Self {
Self::new()
}
}

impl Debug for BloomFilter {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let str: String = self
.0
.chunks_exact(32)
.map(|chunk| chunk.iter().fold(0, |acc, b| acc | b))
.map(|b| char::from_u32(0x2800 + b as u32).unwrap())
.collect();

f.debug_tuple("BloomFilter").field(&str).finish()
}
}

pub struct RemoteBloomFilter {
filter: BloomFilter,
chunk_size: u16,
Expand Down Expand Up @@ -167,7 +193,7 @@ impl RemoteBloomFilter {
pub fn handle_response(
&mut self,
cookie: NtpClientCookie,
response: ReferenceIdResponse,
response: &ReferenceIdResponse,
) -> Result<(), ResponseHandlingError> {
let Some((offset, expected_cookie)) = self.last_requested else {
return Err(ResponseHandlingError::NotAwaitingResponse);
Expand Down Expand Up @@ -199,6 +225,17 @@ impl RemoteBloomFilter {
}
}

impl Debug for RemoteBloomFilter {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RemoteBloomFilter")
.field("chunk_size", &self.chunk_size)
.field("last_requested", &self.last_requested)
.field("next_to_request", &self.next_to_request)
.field("is_filled", &self.is_filled)
.finish()
}
}

#[derive(Debug, Copy, Clone)]
pub enum ResponseHandlingError {
NotAwaitingResponse,
Expand Down Expand Up @@ -285,7 +322,7 @@ mod tests {
assert!(matches!(
bf.handle_response(
NtpClientCookie::new_random(),
ReferenceIdResponse::new(&[0u8; 16]).unwrap()
&ReferenceIdResponse::new(&[0u8; 16]).unwrap()
),
Err(NotAwaitingResponse)
));
Expand All @@ -296,18 +333,18 @@ mod tests {
assert_eq!(req.payload_len(), chunk_size);

assert!(matches!(
bf.handle_response(cookie, ReferenceIdResponse::new(&[0; 24]).unwrap()),
bf.handle_response(cookie, &ReferenceIdResponse::new(&[0; 24]).unwrap()),
Err(MismatchedLength)
));

let mut wrong_cookie = cookie;
wrong_cookie.0[0] ^= 0xFF; // Flip all bits in first byte
assert!(matches!(
bf.handle_response(wrong_cookie, ReferenceIdResponse::new(&[0; 16]).unwrap()),
bf.handle_response(wrong_cookie, &ReferenceIdResponse::new(&[0; 16]).unwrap()),
Err(MismatchedCookie)
));

bf.handle_response(cookie, ReferenceIdResponse::new(&[1; 16]).unwrap())
bf.handle_response(cookie, &ReferenceIdResponse::new(&[1; 16]).unwrap())
.unwrap();
assert_eq!(bf.next_to_request, 16);
assert_eq!(bf.last_requested, None);
Expand All @@ -323,7 +360,7 @@ mod tests {
assert!(bf.full_filter().is_none());
let bytes: Vec<_> = (0..req.payload_len()).map(|_| chunk as u8 + 1).collect();
let response = ReferenceIdResponse::new(&bytes).unwrap();
bf.handle_response(cookie, response).unwrap();
bf.handle_response(cookie, &response).unwrap();
}

assert_eq!(bf.next_to_request, 0);
Expand All @@ -346,7 +383,7 @@ mod tests {
let cookie = NtpClientCookie::new_random();
let request = bf.next_request(cookie);
let response = request.to_response(&target_filter).unwrap();
bf.handle_response(cookie, response).unwrap();
bf.handle_response(cookie, &response).unwrap();
}

let result_filter = bf.full_filter().unwrap();
Expand Down
Loading

0 comments on commit fd6586c

Please sign in to comment.