Skip to content

Commit

Permalink
Add a state machine for protocol upgrades to the peer
Browse files Browse the repository at this point in the history
  • Loading branch information
tdittr committed Oct 31, 2023
1 parent c5219c2 commit 4a2bc7d
Show file tree
Hide file tree
Showing 3 changed files with 279 additions and 28 deletions.
26 changes: 19 additions & 7 deletions ntp-proto/src/packet/extension_fields.rs
Original file line number Diff line number Diff line change
Expand Up @@ -446,16 +446,25 @@ impl<'a> ExtensionField<'a> {
#[cfg(feature = "ntpv5")]
fn decode_draft_identification(
message: &'a [u8],
extension_header_version: ExtensionHeaderVersion,
) -> Result<Self, ParsingError<std::convert::Infallible>> {
let di = match core::str::from_utf8(message) {
Ok(di) if di.is_ascii() => di,
_ => return Err(super::v5::V5Error::InvalidDraftIdentification.into()),
};

let di = match extension_header_version {
ExtensionHeaderVersion::V4 => di.trim_end_matches('\0'),
ExtensionHeaderVersion::V5 => di,
};

Ok(ExtensionField::DraftIdentification(Cow::Borrowed(di)))
}

fn decode(raw: RawExtensionField<'a>) -> Result<Self, ParsingError<std::convert::Infallible>> {
fn decode(
raw: RawExtensionField<'a>,
extension_header_version: ExtensionHeaderVersion,
) -> Result<Self, ParsingError<std::convert::Infallible>> {
type EF<'a> = ExtensionField<'a>;
type TypeId = ExtensionFieldTypeId;

Expand All @@ -466,7 +475,9 @@ impl<'a> ExtensionField<'a> {
TypeId::NtsCookie => EF::decode_nts_cookie(message),
TypeId::NtsCookiePlaceholder => EF::decode_nts_cookie_placeholder(message),
#[cfg(feature = "ntpv5")]
TypeId::DraftIdentification => EF::decode_draft_identification(message),
TypeId::DraftIdentification => {
EF::decode_draft_identification(message, extension_header_version)
}
type_id => EF::decode_unknown(type_id.to_type_id(), message),
}
}
Expand Down Expand Up @@ -617,7 +628,8 @@ impl<'a> ExtensionFieldData<'a> {
efdata.authenticated.append(&mut efdata.untrusted);
}
_ => {
let field = ExtensionField::decode(field).map_err(|e| e.generalize())?;
let field =
ExtensionField::decode(field, version).map_err(|e| e.generalize())?;
efdata.untrusted.push(field);
}
}
Expand Down Expand Up @@ -701,7 +713,7 @@ impl<'a> RawEncryptedField<'a> {
// TODO: Discuss whether we want this check
Err(ParsingError::MalformedNtsExtensionFields)
} else {
Ok(ExtensionField::decode(encrypted_field)
Ok(ExtensionField::decode(encrypted_field, version)
.map_err(|e| e.generalize())?
.into_owned())
}
Expand Down Expand Up @@ -931,15 +943,15 @@ mod tests {
type_id: ExtensionFieldTypeId::NtsCookiePlaceholder,
message_bytes: &[1; COOKIE_LENGTH],
};
let output = ExtensionField::decode(raw).unwrap_err();
let output = ExtensionField::decode(raw, ExtensionHeaderVersion::V4).unwrap_err();

assert!(matches!(output, ParsingError::MalformedCookiePlaceholder));

let raw = RawExtensionField {
type_id: ExtensionFieldTypeId::NtsCookiePlaceholder,
message_bytes: &[0; COOKIE_LENGTH],
};
let output = ExtensionField::decode(raw).unwrap();
let output = ExtensionField::decode(raw, ExtensionHeaderVersion::V4).unwrap();

let ExtensionField::NtsCookiePlaceholder { cookie_length } = output else {
panic!("incorrect variant");
Expand Down Expand Up @@ -972,7 +984,7 @@ mod tests {
data.extend(&[0]); // Padding

let raw = RawExtensionField::deserialize(&data, 4, ExtensionHeaderVersion::V5).unwrap();
let ef = ExtensionField::decode(raw).unwrap();
let ef = ExtensionField::decode(raw, ExtensionHeaderVersion::V4).unwrap();

let ExtensionField::DraftIdentification(ref parsed) = ef else {
panic!("Unexpected extensionfield {ef:?}... expected DraftIdentification");
Expand Down
64 changes: 45 additions & 19 deletions ntp-proto/src/packet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,27 @@ impl<'a> NtpPacket<'a> {
)
}

#[cfg(feature = "ntpv5")]
pub fn poll_message_upgrade_request(poll_interval: PollInterval) -> (Self, RequestIdentifier) {
let (mut header, id) = NtpHeaderV3V4::poll_message(poll_interval);

header.reference_timestamp = v5::UPGRADE_TIMESTAMP;
let draft_id = ExtensionField::DraftIdentification(Cow::Borrowed(v5::DRAFT_VERSION));

(
NtpPacket {
header: NtpHeader::V4(header),
efdata: ExtensionFieldData {
authenticated: vec![],
encrypted: vec![],
untrusted: vec![draft_id],
},
mac: None,
},
id,
)
}

#[cfg(feature = "ntpv5")]
pub fn poll_message_v5(poll_interval: PollInterval) -> (Self, RequestIdentifier) {
let (header, id) = v5::NtpHeaderV5::poll_message(poll_interval);
Expand Down Expand Up @@ -560,6 +581,7 @@ impl<'a> NtpPacket<'a> {
NtpHeader::V4(header) => {
let mut response_header =
NtpHeaderV3V4::timestamp_response(system, header, recv_timestamp, clock);
let mut extra_ef = None;

#[cfg(feature = "ntpv5")]
{
Expand All @@ -569,6 +591,9 @@ impl<'a> NtpPacket<'a> {
(header.reference_timestamp, input.draft_id())
{
response_header.reference_timestamp = v5::UPGRADE_TIMESTAMP;
extra_ef = Some(ExtensionField::DraftIdentification(Cow::Borrowed(
v5::DRAFT_VERSION,
)));
};
}

Expand All @@ -584,6 +609,7 @@ impl<'a> NtpPacket<'a> {
.into_iter()
.chain(input.efdata.authenticated)
.filter(|ef| matches!(ef, ExtensionField::UniqueIdentifier(_)))
.chain(extra_ef)
.collect(),
},
mac: None,
Expand Down Expand Up @@ -937,6 +963,20 @@ impl<'a> NtpPacket<'a> {
self.is_kiss() && self.reference_id().is_ntsn()
}

#[cfg(feature = "ntpv5")]
pub fn is_upgrade(&self) -> bool {
match (self.header, self.draft_id()) {

Check failure on line 968 in ntp-proto/src/packet/mod.rs

View workflow job for this annotation

GitHub Actions / Clippy (x86_64-unknown-linux-gnu, false, true)

match expression looks like `matches!` macro

Check failure on line 968 in ntp-proto/src/packet/mod.rs

View workflow job for this annotation

GitHub Actions / Clippy (x86_64-apple-darwin, true, false, -target x86_64-macos-gnu -g)

match expression looks like `matches!` macro
(
NtpHeader::V4(NtpHeaderV3V4 {
reference_timestamp: v5::UPGRADE_TIMESTAMP,
..
}),
Some(v5::DRAFT_VERSION),
) => true,
_ => false,
}
}

pub fn valid_server_response(&self, identifier: RequestIdentifier, nts_enabled: bool) -> bool {
if let Some(uid) = identifier.uid {
let auth = check_uid_extensionfield(self.efdata.authenticated.iter(), &uid);
Expand Down Expand Up @@ -1508,21 +1548,10 @@ mod tests {
assert!(!response.valid_server_response(id, true));
}

#[cfg(feature = "ntpv5")]
#[test]
fn v5_upgrade_packet() {
let (mut packet, _) = NtpPacket::poll_message(PollInterval::default());
let NtpHeader::V4(header) = &mut packet.header else {
panic!("wrong version");
};
header.reference_timestamp = NtpTimestamp::from_fixed_int(0x4E5450354E545035);

#[cfg(feature = "ntpv5")]
packet
.efdata
.untrusted
.push(ExtensionField::DraftIdentification(Cow::Borrowed(
v5::DRAFT_VERSION,
)));
let (packet, _) = NtpPacket::poll_message_upgrade_request(PollInterval::default());

let response = NtpPacket::timestamp_response(
&SystemSnapshot::default(),
Expand All @@ -1537,13 +1566,10 @@ mod tests {
panic!("wrong version");
};

let expect = if cfg!(feature = "ntpv5") {
assert_eq!(
header.reference_timestamp,
NtpTimestamp::from_fixed_int(0x4E5450354E545035)
} else {
NtpTimestamp::from_fixed_int(0)
};

assert_eq!(header.reference_timestamp, expect);
);
}

#[test]
Expand Down
Loading

0 comments on commit 4a2bc7d

Please sign in to comment.