diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index ed9e7b07b..ed868dc74 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -29,6 +29,7 @@ jobs: - "" features: - "" + - "--features ntpv5" steps: - name: Checkout sources uses: actions/checkout@8ade135a41bc03ea155e62e844d188df1ea18608 @@ -67,6 +68,7 @@ jobs: os: [ubuntu-latest] features: - "" + - "--features ntpv5" steps: - name: Checkout sources uses: actions/checkout@8ade135a41bc03ea155e62e844d188df1ea18608 @@ -96,6 +98,7 @@ jobs: - "" features: - "" + - "--features ntpv5" steps: - name: Checkout sources uses: actions/checkout@8ade135a41bc03ea155e62e844d188df1ea18608 @@ -243,7 +246,7 @@ jobs: uses: actions-rs/cargo@844f36862e911db73fe0815f00a4a2602c279505 with: command: clippy - args: --workspace --all-targets -- -D warnings + args: --workspace --all-targets --all-features -- -D warnings - name: Run clippy (fuzzers) uses: actions-rs/cargo@844f36862e911db73fe0815f00a4a2602c279505 with: @@ -287,7 +290,7 @@ jobs: TARGET_CC: "/home/runner/.cargo/bin/cargo-zigbuild zig cc -- -target arm-linux-gnueabihf -mcpu=generic+v7a+vfp3-d32+thumb2-neon -g" with: command: clippy - args: --target armv7-unknown-linux-gnueabihf --workspace --all-targets -- -D warnings + args: --target armv7-unknown-linux-gnueabihf --workspace --all-targets --all-features -- -D warnings clippy-macos: name: ClippyMacOS @@ -321,7 +324,7 @@ jobs: TARGET_CC: "/home/runner/.cargo/bin/cargo-zigbuild zig cc -- -target x86_64-macos-gnu -g" with: command: clippy - args: --target x86_64-apple-darwin --workspace --all-targets -- -D warnings + args: --target x86_64-apple-darwin --workspace --all-targets --all-features -- -D warnings clippy-musl: name: ClippyMusl @@ -355,11 +358,16 @@ jobs: TARGET_CC: "/home/runner/.cargo/bin/cargo-zigbuild zig cc -- -target x86_64-linux-musl" with: command: clippy - args: --target x86_64-unknown-linux-musl --workspace --all-targets -- -D warnings + args: --target x86_64-unknown-linux-musl --workspace --all-targets --all-features -- -D warnings fuzz: name: Smoke-test fuzzing targets runs-on: ubuntu-20.04 + strategy: + matrix: + features: + - "" + - "--features ntpv5" steps: - name: Checkout sources uses: actions/checkout@8ade135a41bc03ea155e62e844d188df1ea18608 @@ -377,9 +385,9 @@ jobs: tool: cargo-fuzz - name: Smoke-test fuzz targets run: | - cargo fuzz build - for target in $(cargo fuzz list) ; do - cargo fuzz run $target -- -max_total_time=10 + cargo fuzz build ${{ matrix.features }} + for target in $(cargo fuzz list ${{ matrix.features }}) ; do + cargo fuzz run ${{ matrix.features }} $target -- -max_total_time=10 done audit-dependencies: diff --git a/Cargo.lock b/Cargo.lock index e8dbe9746..09125c734 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -62,9 +62,9 @@ checksum = "a2e1373abdaa212b704512ec2bd8b26bd0b7d5c3f70117411a5d9a451383c859" [[package]] name = "async-trait" -version = "0.1.73" +version = "0.1.74" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc00ceb34980c03614e35a3a4e218276a0a824e911d07651cd0d858a51e8c0f0" +checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9" dependencies = [ "proc-macro2", "quote", @@ -473,9 +473,9 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.67" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d433d9f1a3e8c1263d9456598b16fec66f4acc9a74dacffd35c7bb09b3a1328" +checksum = "134c189feb4956b20f6f547d2cf727d4c0fe06722b20a0eec87ed445a97f92da" dependencies = [ "unicode-ident", ] @@ -673,9 +673,9 @@ dependencies = [ [[package]] name = "sharded-slab" -version = "0.1.6" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1b21f559e07218024e7e9f90f96f601825397de0e25420135f7f952453fed0b" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" dependencies = [ "lazy_static", ] @@ -704,9 +704,9 @@ checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" [[package]] name = "syn" -version = "2.0.37" +version = "2.0.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7303ef2c05cd654186cb250d29049a24840ca25d2747c25c0381c8d9e2f582e8" +checksum = "e96b79aaa137db8f61e26363a0c9b47d8b4ec75da28b7d1d614c2303e232408b" dependencies = [ "proc-macro2", "quote", @@ -807,11 +807,10 @@ dependencies = [ [[package]] name = "tracing" -version = "0.1.37" +version = "0.1.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8" +checksum = "ee2ef2af84856a50c1d430afce2fdded0a4ec7eda868db86409b4543df0797f9" dependencies = [ - "cfg-if", "pin-project-lite", "tracing-attributes", "tracing-core", @@ -819,9 +818,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.26" +version = "0.1.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f4f31f56159e98206da9efd823404b79b6ef3143b4a7ab76e67b1751b25a4ab" +checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", @@ -830,9 +829,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.31" +version = "0.1.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0955b8137a1df6f1a2e9a37d8a6656291ff0297c1a97c24e0d8425fe2312f79a" +checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" dependencies = [ "once_cell", ] @@ -1033,9 +1032,9 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "winnow" -version = "0.5.15" +version = "0.5.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c2e3184b9c4e92ad5167ca73039d0c42476302ab603e2fec4487511f38ccefc" +checksum = "a3b801d0e0a6726477cc207f60162da452f3a95adb368399bef20a946e06f65c" dependencies = [ "memchr", ] diff --git a/fuzz/Cargo.lock b/fuzz/Cargo.lock index 60b068037..42e799595 100644 --- a/fuzz/Cargo.lock +++ b/fuzz/Cargo.lock @@ -74,12 +74,6 @@ dependencies = [ "syn", ] -[[package]] -name = "autocfg" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" - [[package]] name = "backtrace" version = "0.3.69" @@ -241,12 +235,6 @@ dependencies = [ "subtle", ] -[[package]] -name = "dtoa" -version = "1.0.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcbb2bf8e87535c23f7a8a321e364ce21462d0ff10cb6407820e8e96dfff6653" - [[package]] name = "equivalent" version = "1.0.1" @@ -358,16 +346,6 @@ dependencies = [ "once_cell", ] -[[package]] -name = "lock_api" -version = "0.4.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16" -dependencies = [ - "autocfg", - "scopeguard", -] - [[package]] name = "log" version = "0.4.20" @@ -411,7 +389,7 @@ dependencies = [ [[package]] name = "ntp-os-clock" -version = "0.3.6" +version = "1.0.0" dependencies = [ "libc", "ntp-proto", @@ -420,7 +398,7 @@ dependencies = [ [[package]] name = "ntp-proto" -version = "0.3.6" +version = "1.0.0" dependencies = [ "aead", "aes-siv", @@ -446,7 +424,7 @@ dependencies = [ [[package]] name = "ntp-udp" -version = "0.3.6" +version = "1.0.0" dependencies = [ "libc", "ntp-proto", @@ -457,14 +435,13 @@ dependencies = [ [[package]] name = "ntpd" -version = "0.3.6" +version = "1.0.0" dependencies = [ "async-trait", "libc", "ntp-os-clock", "ntp-proto", "ntp-udp", - "prometheus-client", "rand 0.8.5", "rustls", "rustls-native-certs", @@ -525,29 +502,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" -[[package]] -name = "parking_lot" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" -dependencies = [ - "lock_api", - "parking_lot_core", -] - -[[package]] -name = "parking_lot_core" -version = "0.9.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93f00c865fe7cabf650081affecd3871070f26767e7b2070a3ffae14c654b447" -dependencies = [ - "cfg-if", - "libc", - "redox_syscall", - "smallvec", - "windows-targets", -] - [[package]] name = "pin-project-lite" version = "0.2.12" @@ -569,29 +523,6 @@ dependencies = [ "unicode-ident", ] -[[package]] -name = "prometheus-client" -version = "0.21.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c99afa9a01501019ac3a14d71d9f94050346f55ca471ce90c799a15c58f61e2" -dependencies = [ - "dtoa", - "itoa", - "parking_lot", - "prometheus-client-derive-encode", -] - -[[package]] -name = "prometheus-client-derive-encode" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "440f724eba9f6996b75d63681b0a92b06947f1457076d503a4d2e2c8f56442b8" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "quote" version = "1.0.33" @@ -635,15 +566,6 @@ dependencies = [ "getrandom", ] -[[package]] -name = "redox_syscall" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" -dependencies = [ - "bitflags", -] - [[package]] name = "ring" version = "0.16.20" @@ -723,12 +645,6 @@ dependencies = [ "windows-sys", ] -[[package]] -name = "scopeguard" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" - [[package]] name = "sct" version = "0.7.0" @@ -811,12 +727,6 @@ dependencies = [ "lazy_static", ] -[[package]] -name = "smallvec" -version = "1.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9" - [[package]] name = "socket2" version = "0.5.3" diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml index 0fe015e28..c5382c115 100644 --- a/fuzz/Cargo.toml +++ b/fuzz/Cargo.toml @@ -8,6 +8,9 @@ publish = false [package.metadata] cargo-fuzz = true +[features] +ntpv5 = ["ntp-proto/ntpv5"] + [dependencies] rand = "0.8.5" diff --git a/fuzz/fuzz_targets/encrypted_server_parsing.rs b/fuzz/fuzz_targets/encrypted_server_parsing.rs index a3fe60966..526a8ccd9 100644 --- a/fuzz/fuzz_targets/encrypted_server_parsing.rs +++ b/fuzz/fuzz_targets/encrypted_server_parsing.rs @@ -6,7 +6,9 @@ use std::{ }; use libfuzzer_sys::fuzz_target; -use ntp_proto::{test_cookie, EncryptResult, ExtensionField, KeySetProvider, NtpPacket}; +use ntp_proto::{ + test_cookie, EncryptResult, ExtensionField, ExtensionHeaderVersion, KeySetProvider, NtpPacket, +}; use rand::{rngs::StdRng, set_thread_rng, SeedableRng}; const fn next_multiple_of(lhs: u16, rhs: u16) -> u16 { @@ -16,7 +18,14 @@ const fn next_multiple_of(lhs: u16, rhs: u16) -> u16 { } } -fuzz_target!(|parts: (Vec, Vec, Vec, Vec, u64)| { +fuzz_target!(|parts: ( + Vec, + Vec, + Vec, + Vec, + u64, + ExtensionHeaderVersion +)| { set_thread_rng(StdRng::seed_from_u64(parts.4)); // Can't test reencoding because of the keyset @@ -29,7 +38,11 @@ fuzz_target!(|parts: (Vec, Vec, Vec, Vec, u64)| { let _ = cursor.write_all(&parts.0); let cookie = test_cookie(); let enc_cookie = keyset.encode_cookie_pub(&cookie); - let _ = ExtensionField::NtsCookie(Cow::Borrowed(&enc_cookie)).serialize_pub(&mut cursor, 4); + let _ = ExtensionField::NtsCookie(Cow::Borrowed(&enc_cookie)).serialize_pub( + &mut cursor, + 4, + parts.5, + ); let _ = cursor.write_all(&parts.1); let mut ciphertext = parts.2.clone(); diff --git a/ntp-proto/Cargo.toml b/ntp-proto/Cargo.toml index 934b9dfb0..3fc13d1ab 100644 --- a/ntp-proto/Cargo.toml +++ b/ntp-proto/Cargo.toml @@ -16,6 +16,7 @@ rust-version.workspace = true __internal-fuzz = ["arbitrary", "__internal-api"] __internal-test = ["__internal-api"] __internal-api = [] +ntpv5 = [] [dependencies] # Note: md5 is needed to calculate ReferenceIDs for IPv6 addresses per RFC5905 diff --git a/ntp-proto/src/lib.rs b/ntp-proto/src/lib.rs index 0add3f99b..f18ffc209 100644 --- a/ntp-proto/src/lib.rs +++ b/ntp-proto/src/lib.rs @@ -43,8 +43,8 @@ mod exports { #[cfg(feature = "__internal-fuzz")] pub use super::packet::ExtensionField; pub use super::packet::{ - Cipher, CipherProvider, EncryptResult, NoCipher, NtpAssociationMode, NtpLeapIndicator, - NtpPacket, PacketParsingError, + Cipher, CipherProvider, EncryptResult, ExtensionHeaderVersion, NoCipher, + NtpAssociationMode, NtpLeapIndicator, NtpPacket, PacketParsingError, }; #[cfg(feature = "__internal-fuzz")] pub use super::peer::fuzz_measurement_from_packet; diff --git a/ntp-proto/src/packet/crypto.rs b/ntp-proto/src/packet/crypto.rs index fbb1abe21..3d2b29bd6 100644 --- a/ntp-proto/src/packet/crypto.rs +++ b/ntp-proto/src/packet/crypto.rs @@ -5,7 +5,7 @@ use zeroize::{Zeroize, ZeroizeOnDrop}; use crate::keyset::DecodedServerCookie; -use super::extensionfields::ExtensionField; +use super::extension_fields::ExtensionField; #[derive(Debug, thiserror::Error)] #[error("Could not decrypt ciphertext")] diff --git a/ntp-proto/src/packet/error.rs b/ntp-proto/src/packet/error.rs index 094813832..b50b34b99 100644 --- a/ntp-proto/src/packet/error.rs +++ b/ntp-proto/src/packet/error.rs @@ -10,6 +10,8 @@ pub enum ParsingError { MalformedNonce, MalformedCookiePlaceholder, DecryptError(T), + #[cfg(feature = "ntpv5")] + V5(super::v5::V5Error), } impl ParsingError { @@ -23,6 +25,8 @@ impl ParsingError { MalformedNonce => Err(MalformedNonce), MalformedCookiePlaceholder => Err(MalformedCookiePlaceholder), DecryptError(decrypt_error) => Ok(decrypt_error), + #[cfg(feature = "ntpv5")] + V5(e) => Err(V5(e)), } } } @@ -38,6 +42,8 @@ impl ParsingError { MalformedNonce => MalformedNonce, MalformedCookiePlaceholder => MalformedCookiePlaceholder, DecryptError(decrypt_error) => match decrypt_error {}, + #[cfg(feature = "ntpv5")] + V5(e) => V5(e), } } } @@ -53,6 +59,8 @@ impl Display for ParsingError { Self::MalformedNonce => f.write_str("Malformed nonce (likely invalid length)"), Self::MalformedCookiePlaceholder => f.write_str("Malformed cookie placeholder"), Self::DecryptError(_) => f.write_str("Failed to decrypt NTS extension fields"), + #[cfg(feature = "ntpv5")] + Self::V5(e) => Display::fmt(e, f), } } } diff --git a/ntp-proto/src/packet/extensionfields.rs b/ntp-proto/src/packet/extension_fields.rs similarity index 73% rename from ntp-proto/src/packet/extensionfields.rs rename to ntp-proto/src/packet/extension_fields.rs index 3bdbd7ef6..4c9ae10f7 100644 --- a/ntp-proto/src/packet/extensionfields.rs +++ b/ntp-proto/src/packet/extension_fields.rs @@ -13,7 +13,11 @@ enum ExtensionFieldTypeId { NtsCookie, NtsCookiePlaceholder, NtsEncryptedField, - Unknown { type_id: u16 }, + Unknown { + type_id: u16, + }, + #[cfg(feature = "ntpv5")] + DraftIdentification, } impl ExtensionFieldTypeId { @@ -23,6 +27,8 @@ impl ExtensionFieldTypeId { 0x204 => Self::NtsCookie, 0x304 => Self::NtsCookiePlaceholder, 0x404 => Self::NtsEncryptedField, + #[cfg(feature = "ntpv5")] + 0xF5FF => Self::DraftIdentification, _ => Self::Unknown { type_id }, } } @@ -33,6 +39,8 @@ impl ExtensionFieldTypeId { ExtensionFieldTypeId::NtsCookie => 0x204, ExtensionFieldTypeId::NtsCookiePlaceholder => 0x304, ExtensionFieldTypeId::NtsEncryptedField => 0x404, + #[cfg(feature = "ntpv5")] + ExtensionFieldTypeId::DraftIdentification => 0xF5FF, ExtensionFieldTypeId::Unknown { type_id } => type_id, } } @@ -42,9 +50,16 @@ impl ExtensionFieldTypeId { pub enum ExtensionField<'a> { UniqueIdentifier(Cow<'a, [u8]>), NtsCookie(Cow<'a, [u8]>), - NtsCookiePlaceholder { cookie_length: u16 }, + NtsCookiePlaceholder { + cookie_length: u16, + }, InvalidNtsEncryptedField, - Unknown { type_id: u16, data: Cow<'a, [u8]> }, + #[cfg(feature = "ntpv5")] + DraftIdentification(Cow<'a, str>), + Unknown { + type_id: u16, + data: Cow<'a, [u8]>, + }, } impl<'a> std::fmt::Debug for ExtensionField<'a> { @@ -59,6 +74,10 @@ impl<'a> std::fmt::Debug for ExtensionField<'a> { .field("body_length", body_length) .finish(), Self::InvalidNtsEncryptedField => f.debug_struct("InvalidNtsEncryptedField").finish(), + #[cfg(feature = "ntpv5")] + Self::DraftIdentification(arg0) => { + f.debug_tuple("DraftIdentification").field(arg0).finish() + } Self::Unknown { type_id: typeid, data, @@ -92,22 +111,35 @@ impl<'a> ExtensionField<'a> { cookie_length: body_length, }, InvalidNtsEncryptedField => InvalidNtsEncryptedField, + #[cfg(feature = "ntpv5")] + DraftIdentification(data) => DraftIdentification(Cow::Owned(data.into_owned())), } } - fn serialize(&self, w: &mut W, minimum_size: u16) -> std::io::Result<()> { + fn serialize( + &self, + w: &mut W, + minimum_size: u16, + version: ExtensionHeaderVersion, + ) -> std::io::Result<()> { use ExtensionField::*; match self { - Unknown { type_id, data } => Self::encode_unknown(w, *type_id, data, minimum_size), + Unknown { type_id, data } => { + Self::encode_unknown(w, *type_id, data, minimum_size, version) + } UniqueIdentifier(identifier) => { - Self::encode_unique_identifier(w, identifier, minimum_size) + Self::encode_unique_identifier(w, identifier, minimum_size, version) } - NtsCookie(cookie) => Self::encode_nts_cookie(w, cookie, minimum_size), + NtsCookie(cookie) => Self::encode_nts_cookie(w, cookie, minimum_size, version), NtsCookiePlaceholder { cookie_length: body_length, - } => Self::encode_nts_cookie_placeholder(w, *body_length, minimum_size), + } => Self::encode_nts_cookie_placeholder(w, *body_length, minimum_size, version), InvalidNtsEncryptedField => Err(std::io::ErrorKind::Other.into()), + #[cfg(feature = "ntpv5")] + DraftIdentification(data) => { + Self::encode_draft_identification(w, data, minimum_size, version) + } } } @@ -116,8 +148,9 @@ impl<'a> ExtensionField<'a> { &self, w: &mut W, minimum_size: u16, + version: ExtensionHeaderVersion, ) -> std::io::Result<()> { - self.serialize(w, minimum_size) + self.serialize(w, minimum_size, version) } fn encode_framing( @@ -125,6 +158,7 @@ impl<'a> ExtensionField<'a> { ef_id: ExtensionFieldTypeId, data_length: usize, minimum_size: u16, + version: ExtensionHeaderVersion, ) -> std::io::Result<()> { if data_length > u16::MAX as usize - 4 { return Err(std::io::Error::new( @@ -135,9 +169,12 @@ impl<'a> ExtensionField<'a> { // u16 for the type_id, u16 for the length let header_width = 4; + let mut actual_length = (data_length as u16 + header_width).max(minimum_size); + + if version == ExtensionHeaderVersion::V4 { + actual_length = next_multiple_of_u16(actual_length, 4) + } - let actual_length = - next_multiple_of_u16((data_length as u16 + header_width).max(minimum_size), 4); w.write_all(&ef_id.to_type_id().to_be_bytes())?; w.write_all(&actual_length.to_be_bytes()) } @@ -180,12 +217,14 @@ impl<'a> ExtensionField<'a> { w: &mut W, identifier: &[u8], minimum_size: u16, + version: ExtensionHeaderVersion, ) -> std::io::Result<()> { Self::encode_framing( w, ExtensionFieldTypeId::UniqueIdentifier, identifier.len(), minimum_size, + version, )?; w.write_all(identifier)?; Self::encode_padding(w, identifier.len(), minimum_size) @@ -195,12 +234,14 @@ impl<'a> ExtensionField<'a> { w: &mut W, cookie: &[u8], minimum_size: u16, + version: ExtensionHeaderVersion, ) -> std::io::Result<()> { Self::encode_framing( w, ExtensionFieldTypeId::NtsCookie, cookie.len(), minimum_size, + version, )?; w.write_all(cookie)?; @@ -214,12 +255,14 @@ impl<'a> ExtensionField<'a> { w: &mut W, cookie_length: u16, minimum_size: u16, + version: ExtensionHeaderVersion, ) -> std::io::Result<()> { Self::encode_framing( w, ExtensionFieldTypeId::NtsCookiePlaceholder, cookie_length as usize, minimum_size, + version, )?; Self::write_zeros(w, cookie_length)?; @@ -234,12 +277,14 @@ impl<'a> ExtensionField<'a> { type_id: u16, data: &[u8], minimum_size: u16, + version: ExtensionHeaderVersion, ) -> std::io::Result<()> { Self::encode_framing( w, ExtensionFieldTypeId::Unknown { type_id }, data.len(), minimum_size, + version, )?; w.write_all(data)?; @@ -253,6 +298,7 @@ impl<'a> ExtensionField<'a> { w: &mut Cursor<&mut [u8]>, fields_to_encrypt: &[ExtensionField], cipher: &dyn Cipher, + version: ExtensionHeaderVersion, ) -> std::io::Result<()> { let padding = [0; 4]; @@ -271,7 +317,7 @@ impl<'a> ExtensionField<'a> { // RFC 8915, section 5.5: contrary to the RFC 7822 requirement that fields have a minimum length of 16 or 28 octets, // encrypted extension fields MAY be arbitrarily short (but still MUST be a multiple of 4 octets in length) let minimum_size = 0; - field.serialize(w, minimum_size)?; + field.serialize(w, minimum_size, version)?; } let plaintext_length = w.position() - plaintext_start; @@ -335,6 +381,28 @@ impl<'a> ExtensionField<'a> { Ok(()) } + #[cfg(feature = "ntpv5")] + fn encode_draft_identification( + w: &mut impl Write, + data: &str, + minimum_size: u16, + version: ExtensionHeaderVersion, + ) -> std::io::Result<()> { + Self::encode_framing( + w, + ExtensionFieldTypeId::DraftIdentification, + data.len(), + minimum_size, + version, + )?; + + w.write_all(data.as_bytes())?; + + Self::encode_padding(w, data.len(), minimum_size)?; + + Ok(()) + } + fn decode_unique_identifier( message: &'a [u8], ) -> Result> { @@ -375,6 +443,18 @@ impl<'a> ExtensionField<'a> { }) } + #[cfg(feature = "ntpv5")] + fn decode_draft_identification( + message: &'a [u8], + ) -> Result> { + let di = match core::str::from_utf8(message) { + Ok(di) if di.is_ascii() => di, + _ => return Err(super::v5::V5Error::InvalidDraftIdentification.into()), + }; + + Ok(ExtensionField::DraftIdentification(Cow::Borrowed(di))) + } + fn decode(raw: RawExtensionField<'a>) -> Result> { type EF<'a> = ExtensionField<'a>; type TypeId = ExtensionFieldTypeId; @@ -385,6 +465,8 @@ impl<'a> ExtensionField<'a> { TypeId::UniqueIdentifier => EF::decode_unique_identifier(message), TypeId::NtsCookie => EF::decode_nts_cookie(message), TypeId::NtsCookiePlaceholder => EF::decode_nts_cookie_placeholder(message), + #[cfg(feature = "ntpv5")] + TypeId::DraftIdentification => EF::decode_draft_identification(message), type_id => EF::decode_unknown(type_id.to_type_id(), message), } } @@ -426,6 +508,7 @@ impl<'a> ExtensionFieldData<'a> { &self, w: &mut Cursor<&mut [u8]>, cipher: &(impl CipherProvider + ?Sized), + version: ExtensionHeaderVersion, ) -> std::io::Result<()> { if !self.authenticated.is_empty() || !self.encrypted.is_empty() { let cipher = match cipher.get(&self.authenticated) { @@ -438,13 +521,13 @@ impl<'a> ExtensionFieldData<'a> { let minimum_size = 16; for field in &self.authenticated { - field.serialize(w, minimum_size)?; + field.serialize(w, minimum_size, version)?; } // RFC 8915, section 5.5: contrary to the RFC 7822 requirement that fields have a minimum length of 16 or 28 octets, // encrypted extension fields MAY be arbitrarily short (but still MUST be a multiple of 4 octets in length) // hence we don't provide a minimum size here - ExtensionField::encode_encrypted(w, &self.encrypted, cipher.as_ref())?; + ExtensionField::encode_encrypted(w, &self.encrypted, cipher.as_ref(), version)?; } // per RFC 7822, section 7.5.1.4. @@ -452,7 +535,7 @@ impl<'a> ExtensionFieldData<'a> { while let Some(field) = it.next() { let is_last = it.peek().is_none(); let minimum_size = if is_last { 28 } else { 16 }; - field.serialize(w, minimum_size)?; + field.serialize(w, minimum_size, version)?; } Ok(()) @@ -463,6 +546,7 @@ impl<'a> ExtensionFieldData<'a> { data: &'a [u8], header_size: usize, cipher: &impl CipherProvider, + version: ExtensionHeaderVersion, ) -> Result, ParsingError>> { use ExtensionField::InvalidNtsEncryptedField; @@ -475,9 +559,10 @@ impl<'a> ExtensionFieldData<'a> { &data[header_size..], Mac::MAXIMUM_SIZE, RawExtensionField::V4_UNENCRYPTED_MINIMUM_SIZE, + version, ) { let (offset, field) = field.map_err(|e| e.generalize())?; - size = offset + field.wire_length(); + size = offset + field.wire_length(version); match field.type_id { ExtensionFieldTypeId::NtsEncryptedField => { let encrypted = RawEncryptedField::from_message_bytes(field.message_bytes) @@ -492,18 +577,21 @@ impl<'a> ExtensionFieldData<'a> { } }; - let encrypted_fields = - match encrypted.decrypt(cipher.as_ref(), &data[..header_size + offset]) { - Ok(encrypted_fields) => encrypted_fields, - Err(e) => { - // early return if it's anything but a decrypt error - e.get_decrypt_error()?; + let encrypted_fields = match encrypted.decrypt( + cipher.as_ref(), + &data[..header_size + offset], + version, + ) { + Ok(encrypted_fields) => encrypted_fields, + Err(e) => { + // early return if it's anything but a decrypt error + e.get_decrypt_error()?; - efdata.untrusted.push(InvalidNtsEncryptedField); - is_valid_nts = false; - continue; - } - }; + efdata.untrusted.push(InvalidNtsEncryptedField); + is_valid_nts = false; + continue; + } + }; // for the current ciphers we allow in non-test code, // the nonce should always be 16 bytes @@ -580,6 +668,7 @@ impl<'a> RawEncryptedField<'a> { &self, cipher: &dyn Cipher, aad: &[u8], + version: ExtensionHeaderVersion, ) -> Result>, ParsingError>> { let plaintext = match cipher.decrypt(self.nonce, self.ciphertext, aad) { Ok(plain) => plain, @@ -590,19 +679,48 @@ impl<'a> RawEncryptedField<'a> { } }; - RawExtensionField::deserialize_sequence(&plaintext, 0, RawExtensionField::BARE_MINIMUM_SIZE) - .map(|encrypted_field| { - let encrypted_field = encrypted_field.map_err(|e| e.generalize())?.1; - if encrypted_field.type_id == ExtensionFieldTypeId::NtsEncryptedField { - // TODO: Discuss whether we want this check - Err(ParsingError::MalformedNtsExtensionFields) - } else { - Ok(ExtensionField::decode(encrypted_field) - .map_err(|e| e.generalize())? - .into_owned()) - } - }) - .collect() + RawExtensionField::deserialize_sequence( + &plaintext, + 0, + RawExtensionField::BARE_MINIMUM_SIZE, + version, + ) + .map(|encrypted_field| { + let encrypted_field = encrypted_field.map_err(|e| e.generalize())?.1; + if encrypted_field.type_id == ExtensionFieldTypeId::NtsEncryptedField { + // TODO: Discuss whether we want this check + Err(ParsingError::MalformedNtsExtensionFields) + } else { + Ok(ExtensionField::decode(encrypted_field) + .map_err(|e| e.generalize())? + .into_owned()) + } + }) + .collect() + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum ExtensionHeaderVersion { + V4, + #[cfg(feature = "ntpv5")] + V5, +} + +#[cfg(feature = "__internal-fuzz")] +impl<'a> arbitrary::Arbitrary<'a> for ExtensionHeaderVersion { + #[cfg(not(feature = "ntpv5"))] + fn arbitrary(_u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result { + Ok(Self::V4) + } + + #[cfg(feature = "ntpv5")] + fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result { + Ok(if bool::arbitrary(u)? { + Self::V4 + } else { + Self::V5 + }) } } @@ -618,14 +736,16 @@ impl<'a> RawExtensionField<'a> { const BARE_MINIMUM_SIZE: usize = 4; const V4_UNENCRYPTED_MINIMUM_SIZE: usize = 4; - fn wire_length(&self) -> usize { + fn wire_length(&self, version: ExtensionHeaderVersion) -> usize { // field type + length + value + padding let length = 2 + 2 + self.message_bytes.len(); - // All extension fields are zero-padded to a word (four octets) boundary. - // - // message_bytes should include this padding, so this should already be true - debug_assert_eq!(length % 4, 0); + if version == ExtensionHeaderVersion::V4 { + // All extension fields are zero-padded to a word (four octets) boundary. + // + // message_bytes should include this padding, so this should already be true + debug_assert_eq!(length % 4, 0); + } next_multiple_of_usize(length, 4) } @@ -633,6 +753,7 @@ impl<'a> RawExtensionField<'a> { fn deserialize( data: &'a [u8], minimum_size: usize, + version: ExtensionHeaderVersion, ) -> Result> { use ParsingError::IncorrectLength; @@ -646,9 +767,13 @@ impl<'a> RawExtensionField<'a> { // the entire extension field in octets, including the Padding field. let field_length = u16::from_be_bytes([b2, b3]) as usize; - // padding is up to a multiple of 4 bytes, so a valid field length is divisible by 4 - if field_length < minimum_size || field_length % 4 != 0 { - return Err(ParsingError::IncorrectLength); + if field_length < minimum_size { + return Err(IncorrectLength); + } + + // In NTPv4: padding is up to a multiple of 4 bytes, so a valid field length is divisible by 4 + if version == ExtensionHeaderVersion::V4 && field_length % 4 != 0 { + return Err(IncorrectLength); } // because the field length includes padding, the message bytes may not exactly match the input @@ -664,6 +789,7 @@ impl<'a> RawExtensionField<'a> { buffer: &'a [u8], cutoff: usize, minimum_size: usize, + version: ExtensionHeaderVersion, ) -> impl Iterator< Item = Result<(usize, RawExtensionField<'a>), ParsingError>, > + 'a { @@ -672,6 +798,7 @@ impl<'a> RawExtensionField<'a> { cutoff, minimum_size, offset: 0, + version, } } } @@ -680,6 +807,7 @@ struct ExtensionFieldStreamer<'a> { cutoff: usize, minimum_size: usize, offset: usize, + version: ExtensionHeaderVersion, } impl<'a> Iterator for ExtensionFieldStreamer<'a> { @@ -692,10 +820,10 @@ impl<'a> Iterator for ExtensionFieldStreamer<'a> { return None; } - match RawExtensionField::deserialize(remaining, self.minimum_size) { + match RawExtensionField::deserialize(remaining, self.minimum_size, self.version) { Ok(field) => { let offset = self.offset; - self.offset += field.wire_length(); + self.offset += field.wire_length(self.version); Some(Ok((offset, field))) } Err(error) => { @@ -724,7 +852,7 @@ const fn next_multiple_of_usize(lhs: usize, rhs: usize) -> usize { mod tests { use crate::{ keyset::KeySet, - packet::{extensionfields::ExtensionFieldTypeId, AesSivCmac256}, + packet::{extension_fields::ExtensionFieldTypeId, AesSivCmac256}, }; use super::*; @@ -741,7 +869,13 @@ mod tests { fn test_unique_identifier() { let identifier: Vec<_> = (0..16).collect(); let mut w = vec![]; - ExtensionField::encode_unique_identifier(&mut w, &identifier, 0).unwrap(); + ExtensionField::encode_unique_identifier( + &mut w, + &identifier, + 0, + ExtensionHeaderVersion::V4, + ) + .unwrap(); assert_eq!( w, @@ -753,7 +887,7 @@ mod tests { fn test_nts_cookie() { let cookie: Vec<_> = (0..16).collect(); let mut w = vec![]; - ExtensionField::encode_nts_cookie(&mut w, &cookie, 0).unwrap(); + ExtensionField::encode_nts_cookie(&mut w, &cookie, 0, ExtensionHeaderVersion::V4).unwrap(); assert_eq!( w, @@ -766,7 +900,13 @@ mod tests { const COOKIE_LENGTH: usize = 16; let mut w = vec![]; - ExtensionField::encode_nts_cookie_placeholder(&mut w, COOKIE_LENGTH as u16, 0).unwrap(); + ExtensionField::encode_nts_cookie_placeholder( + &mut w, + COOKIE_LENGTH as u16, + 0, + ExtensionHeaderVersion::V4, + ) + .unwrap(); assert_eq!( w, @@ -798,7 +938,7 @@ mod tests { fn test_unknown() { let data: Vec<_> = (0..16).collect(); let mut w = vec![]; - ExtensionField::encode_unknown(&mut w, 42, &data, 0).unwrap(); + ExtensionField::encode_unknown(&mut w, 42, &data, 0, ExtensionHeaderVersion::V4).unwrap(); assert_eq!( w, @@ -806,6 +946,60 @@ mod tests { ); } + #[cfg(feature = "ntpv5")] + #[test] + fn draft_identification() { + let test_id = crate::packet::v5::DRAFT_VERSION; + let len = u16::try_from(4 + test_id.len()).unwrap(); + let mut data = vec![]; + data.extend(&[0xF5, 0xFF]); // Type + data.extend(&len.to_be_bytes()); // Length + data.extend(test_id.as_bytes()); // Payload + data.extend(&[0]); // Padding + + let raw = RawExtensionField::deserialize(&data, 4, ExtensionHeaderVersion::V5).unwrap(); + let ef = ExtensionField::decode(raw).unwrap(); + + let ExtensionField::DraftIdentification(ref parsed) = ef else { + panic!("Unexpected extensionfield {ef:?}... expected DraftIdentification"); + }; + + assert_eq!(parsed, test_id); + + let mut out = vec![]; + ef.serialize(&mut out, 4, ExtensionHeaderVersion::V5) + .unwrap(); + + assert_eq!(&out, &data); + } + + #[cfg(feature = "ntpv5")] + #[test] + fn extension_field_length() { + let data: Vec<_> = (0..21).collect(); + let mut w = vec![]; + ExtensionField::encode_unknown(&mut w, 42, &data, 16, ExtensionHeaderVersion::V4).unwrap(); + let raw: RawExtensionField<'_> = + RawExtensionField::deserialize(&w, 16, ExtensionHeaderVersion::V4).unwrap(); + + // v4 extension field header length includes padding bytes + assert_eq!(w[3], 28); + assert_eq!(w.len(), 28); + assert_eq!(raw.message_bytes.len(), 24); + assert_eq!(raw.wire_length(ExtensionHeaderVersion::V4), 28); + + let mut w = vec![]; + ExtensionField::encode_unknown(&mut w, 42, &data, 16, ExtensionHeaderVersion::V5).unwrap(); + let raw: RawExtensionField<'_> = + RawExtensionField::deserialize(&w, 16, ExtensionHeaderVersion::V5).unwrap(); + + // v5 extension field header length does not include padding bytes + assert_eq!(w[3], 25); + assert_eq!(w.len(), 28); + assert_eq!(raw.message_bytes.len(), 21); + assert_eq!(raw.wire_length(ExtensionHeaderVersion::V5), 28); + } + #[test] fn extension_field_minimum_size() { let minimum_size = 32; @@ -813,20 +1007,33 @@ mod tests { let data: Vec<_> = (0..16).collect(); let mut w = vec![]; - ExtensionField::encode_unique_identifier(&mut w, &data, minimum_size).unwrap(); + ExtensionField::encode_unique_identifier( + &mut w, + &data, + minimum_size, + ExtensionHeaderVersion::V4, + ) + .unwrap(); assert_eq!(w.len(), expected_size); let mut w = vec![]; - ExtensionField::encode_nts_cookie(&mut w, &data, minimum_size).unwrap(); + ExtensionField::encode_nts_cookie(&mut w, &data, minimum_size, ExtensionHeaderVersion::V4) + .unwrap(); assert_eq!(w.len(), expected_size); let mut w = vec![]; - ExtensionField::encode_nts_cookie_placeholder(&mut w, data.len() as u16, minimum_size) - .unwrap(); + ExtensionField::encode_nts_cookie_placeholder( + &mut w, + data.len() as u16, + minimum_size, + ExtensionHeaderVersion::V4, + ) + .unwrap(); assert_eq!(w.len(), expected_size); let mut w = vec![]; - ExtensionField::encode_unknown(&mut w, 42, &data, minimum_size).unwrap(); + ExtensionField::encode_unknown(&mut w, 42, &data, minimum_size, ExtensionHeaderVersion::V4) + .unwrap(); assert_eq!(w.len(), expected_size); // NOTE: encryped fields do not have a minimum_size @@ -839,20 +1046,33 @@ mod tests { let data: Vec<_> = (0..15).collect(); // 15 bytes, so padding is needed let mut w = vec![]; - ExtensionField::encode_unique_identifier(&mut w, &data, minimum_size).unwrap(); + ExtensionField::encode_unique_identifier( + &mut w, + &data, + minimum_size, + ExtensionHeaderVersion::V4, + ) + .unwrap(); assert_eq!(w.len(), expected_size); let mut w = vec![]; - ExtensionField::encode_nts_cookie(&mut w, &data, minimum_size).unwrap(); + ExtensionField::encode_nts_cookie(&mut w, &data, minimum_size, ExtensionHeaderVersion::V4) + .unwrap(); assert_eq!(w.len(), expected_size); let mut w = vec![]; - ExtensionField::encode_nts_cookie_placeholder(&mut w, data.len() as u16, minimum_size) - .unwrap(); + ExtensionField::encode_nts_cookie_placeholder( + &mut w, + data.len() as u16, + minimum_size, + ExtensionHeaderVersion::V4, + ) + .unwrap(); assert_eq!(w.len(), expected_size); let mut w = vec![]; - ExtensionField::encode_unknown(&mut w, 42, &data, minimum_size).unwrap(); + ExtensionField::encode_unknown(&mut w, 42, &data, minimum_size, ExtensionHeaderVersion::V4) + .unwrap(); assert_eq!(w.len(), expected_size); let mut w = [0u8; 128]; @@ -862,7 +1082,13 @@ mod tests { let fields_to_encrypt = [ExtensionField::UniqueIdentifier(Cow::Borrowed( data.as_slice(), ))]; - ExtensionField::encode_encrypted(&mut cursor, &fields_to_encrypt, &cipher).unwrap(); + ExtensionField::encode_encrypted( + &mut cursor, + &fields_to_encrypt, + &cipher, + ExtensionHeaderVersion::V4, + ) + .unwrap(); assert_eq!( cursor.position() as usize, 2 + 6 + c2s.len() + expected_size @@ -885,14 +1111,25 @@ mod tests { let mut w = [0u8; 128]; let mut cursor = Cursor::new(w.as_mut_slice()); - ExtensionField::encode_encrypted(&mut cursor, &fields_to_encrypt, &cipher).unwrap(); + ExtensionField::encode_encrypted( + &mut cursor, + &fields_to_encrypt, + &cipher, + ExtensionHeaderVersion::V4, + ) + .unwrap(); let expected_length = 2 + 6 + next_multiple_of_usize(nonce_length, 4) + plaintext_length; assert_eq!(cursor.position() as usize, expected_length,); let message_bytes = &w.as_ref()[..expected_length]; - let mut it = RawExtensionField::deserialize_sequence(message_bytes, 0, 0); + let mut it = RawExtensionField::deserialize_sequence( + message_bytes, + 0, + 0, + ExtensionHeaderVersion::V4, + ); let field = it.next().unwrap().unwrap(); assert!(it.next().is_none()); @@ -905,7 +1142,9 @@ mod tests { }, ) => { let raw = RawEncryptedField::from_message_bytes(message_bytes).unwrap(); - let decrypted_fields = raw.decrypt(&cipher, &[]).unwrap(); + let decrypted_fields = raw + .decrypt(&cipher, &[], ExtensionHeaderVersion::V4) + .unwrap(); assert_eq!(decrypted_fields, fields_to_encrypt); } _ => panic!("invalid"), @@ -927,7 +1166,9 @@ mod tests { let mut w = [0u8; 128]; let mut cursor = Cursor::new(w.as_mut_slice()); - assert!(data.serialize(&mut cursor, &cipher).is_err()); + assert!(data + .serialize(&mut cursor, &cipher, ExtensionHeaderVersion::V4) + .is_err()); } // but succeed when the cipher is not needed @@ -940,7 +1181,9 @@ mod tests { let mut w = [0u8; 128]; let mut cursor = Cursor::new(w.as_mut_slice()); - assert!(data.serialize(&mut cursor, &cipher).is_ok()); + assert!(data + .serialize(&mut cursor, &cipher, ExtensionHeaderVersion::V4) + .is_ok()); } } @@ -959,7 +1202,8 @@ mod tests { let mut w = [0u8; 128]; let mut cursor = Cursor::new(w.as_mut_slice()); - data.serialize(&mut cursor, &cipher).unwrap(); + data.serialize(&mut cursor, &cipher, ExtensionHeaderVersion::V4) + .unwrap(); let n = cursor.position() as usize; let slice = &w.as_slice()[..n]; @@ -984,7 +1228,8 @@ mod tests { let mut w = [0u8; 128]; let mut cursor = Cursor::new(w.as_mut_slice()); - data.serialize(&mut cursor, &cipher).unwrap(); + data.serialize(&mut cursor, &cipher, ExtensionHeaderVersion::V4) + .unwrap(); let n = cursor.position() as usize; let slice = &w.as_slice()[..n]; @@ -1010,14 +1255,16 @@ mod tests { let mut w = [0u8; 128]; let mut cursor = Cursor::new(w.as_mut_slice()); - data.serialize(&mut cursor, &cipher).unwrap(); + data.serialize(&mut cursor, &cipher, ExtensionHeaderVersion::V4) + .unwrap(); let n = cursor.position() as usize; let slice = &w.as_slice()[..n]; let cipher = crate::packet::crypto::NoCipher; - let result = ExtensionFieldData::deserialize(slice, 0, &cipher).unwrap(); + let result = + ExtensionFieldData::deserialize(slice, 0, &cipher, ExtensionHeaderVersion::V4).unwrap(); let DeserializedExtensionField { efdata, @@ -1049,14 +1296,16 @@ mod tests { let mut w = [0u8; 128]; let mut cursor = Cursor::new(w.as_mut_slice()); - data.serialize(&mut cursor, &cipher).unwrap(); + data.serialize(&mut cursor, &cipher, ExtensionHeaderVersion::V4) + .unwrap(); let n = cursor.position() as usize; let slice = &w.as_slice()[..n]; let cipher = crate::packet::crypto::NoCipher; - let result = ExtensionFieldData::deserialize(slice, 0, &cipher).unwrap_err(); + let result = ExtensionFieldData::deserialize(slice, 0, &cipher, ExtensionHeaderVersion::V4) + .unwrap_err(); let ParsingError::DecryptError(InvalidNtsExtensionField { efdata, @@ -1089,7 +1338,8 @@ mod tests { let mut w = [0u8; 128]; let mut cursor = Cursor::new(w.as_mut_slice()); - data.serialize(&mut cursor, &cipher).unwrap(); + data.serialize(&mut cursor, &cipher, ExtensionHeaderVersion::V4) + .unwrap(); let n = cursor.position() as usize; let slice = &w.as_slice()[..n]; @@ -1098,7 +1348,8 @@ mod tests { let c2s = [0; 32]; let cipher = AesSivCmac256::new(c2s.into()); - let result = ExtensionFieldData::deserialize(slice, 0, &cipher).unwrap_err(); + let result = ExtensionFieldData::deserialize(slice, 0, &cipher, ExtensionHeaderVersion::V4) + .unwrap_err(); let ParsingError::DecryptError(InvalidNtsExtensionField { efdata, @@ -1133,12 +1384,14 @@ mod tests { let mut w = [0u8; 256]; let mut cursor = Cursor::new(w.as_mut_slice()); - data.serialize(&mut cursor, &keyset).unwrap(); + data.serialize(&mut cursor, &keyset, ExtensionHeaderVersion::V4) + .unwrap(); let n = cursor.position() as usize; let slice = &w.as_slice()[..n]; - let result = ExtensionFieldData::deserialize(slice, 0, &keyset).unwrap(); + let result = + ExtensionFieldData::deserialize(slice, 0, &keyset, ExtensionHeaderVersion::V4).unwrap(); let DeserializedExtensionField { efdata, diff --git a/ntp-proto/src/packet/mod.rs b/ntp-proto/src/packet/mod.rs index 7b0cad27d..76101b9bb 100644 --- a/ntp-proto/src/packet/mod.rs +++ b/ntp-proto/src/packet/mod.rs @@ -11,19 +11,22 @@ use crate::{ time_types::{NtpDuration, NtpTimestamp, PollInterval}, }; -use self::{error::ParsingError, extensionfields::ExtensionFieldData, mac::Mac}; +use self::{error::ParsingError, extension_fields::ExtensionFieldData, mac::Mac}; mod crypto; mod error; -mod extensionfields; +mod extension_fields; mod mac; +#[cfg(feature = "ntpv5")] +mod v5; + pub use crypto::{ AesSivCmac256, AesSivCmac512, Cipher, CipherHolder, CipherProvider, DecryptError, EncryptResult, NoCipher, }; pub use error::PacketParsingError; -pub use extensionfields::ExtensionField; +pub use extension_fields::{ExtensionField, ExtensionHeaderVersion}; #[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum NtpLeapIndicator { @@ -118,6 +121,8 @@ pub struct NtpPacket<'a> { enum NtpHeader { V3(NtpHeaderV3V4), V4(NtpHeaderV3V4), + #[cfg(feature = "ntpv5")] + V5(v5::NtpHeaderV5), } #[derive(Debug, Copy, Clone, PartialEq, Eq)] @@ -146,7 +151,7 @@ pub struct RequestIdentifier { } impl NtpHeaderV3V4 { - const LENGTH: usize = 48; + const WIRE_LENGTH: usize = 48; /// A new, empty NtpHeader fn new() -> Self { @@ -167,7 +172,7 @@ impl NtpHeaderV3V4 { } fn deserialize(data: &[u8]) -> Result<(Self, usize), ParsingError> { - if data.len() < Self::LENGTH { + if data.len() < Self::WIRE_LENGTH { return Err(ParsingError::IncorrectLength); } @@ -186,7 +191,7 @@ impl NtpHeaderV3V4 { receive_timestamp: NtpTimestamp::from_bits(data[32..40].try_into().unwrap()), transmit_timestamp: NtpTimestamp::from_bits(data[40..48].try_into().unwrap()), }, - Self::LENGTH, + Self::WIRE_LENGTH, )) } @@ -320,7 +325,7 @@ impl<'a> NtpPacket<'a> { let (header, header_size) = NtpHeaderV3V4::deserialize(data).map_err(|e| e.generalize())?; - let contruct_packet = |remaining_bytes: &'a [u8], efdata| { + let construct_packet = |remaining_bytes: &'a [u8], efdata| { let mac = if !remaining_bytes.is_empty() { Some(Mac::deserialize(remaining_bytes)?) } else { @@ -336,9 +341,59 @@ impl<'a> NtpPacket<'a> { Ok::<_, ParsingError>(packet) }; - match ExtensionFieldData::deserialize(data, header_size, cipher) { + match ExtensionFieldData::deserialize( + data, + header_size, + cipher, + ExtensionHeaderVersion::V4, + ) { + Ok(decoded) => { + let packet = construct_packet(decoded.remaining_bytes, decoded.efdata) + .map_err(|e| e.generalize())?; + + Ok((packet, decoded.cookie)) + } + Err(e) => { + // return early if it is anything but a decrypt error + let invalid = e.get_decrypt_error()?; + + let packet = construct_packet(invalid.remaining_bytes, invalid.efdata) + .map_err(|e| e.generalize())?; + + Err(ParsingError::DecryptError(packet)) + } + } + } + #[cfg(feature = "ntpv5")] + 5 => { + let (header, header_size) = + v5::NtpHeaderV5::deserialize(data).map_err(|e| e.generalize())?; + + let construct_packet = |remaining_bytes: &'a [u8], efdata| { + let mac = if !remaining_bytes.is_empty() { + Some(Mac::deserialize(remaining_bytes)?) + } else { + None + }; + + let packet = NtpPacket { + header: NtpHeader::V5(header), + efdata, + mac, + }; + + Ok::<_, ParsingError>(packet) + }; + + // TODO: Check extension field handling in V5 + match ExtensionFieldData::deserialize( + data, + header_size, + cipher, + ExtensionHeaderVersion::V5, + ) { Ok(decoded) => { - let packet = contruct_packet(decoded.remaining_bytes, decoded.efdata) + let packet = construct_packet(decoded.remaining_bytes, decoded.efdata) .map_err(|e| e.generalize())?; Ok((packet, decoded.cookie)) @@ -347,7 +402,7 @@ impl<'a> NtpPacket<'a> { // return early if it is anything but a decrypt error let invalid = e.get_decrypt_error()?; - let packet = contruct_packet(invalid.remaining_bytes, invalid.efdata) + let packet = construct_packet(invalid.remaining_bytes, invalid.efdata) .map_err(|e| e.generalize())?; Err(ParsingError::DecryptError(packet)) @@ -379,11 +434,19 @@ impl<'a> NtpPacket<'a> { match self.header { NtpHeader::V3(header) => header.serialize(w, 3)?, NtpHeader::V4(header) => header.serialize(w, 4)?, + #[cfg(feature = "ntpv5")] + NtpHeader::V5(header) => header.serialize(w)?, }; match self.header { NtpHeader::V3(_) => { /* No extension fields in V3 */ } - NtpHeader::V4(_) => self.efdata.serialize(w, cipher)?, + NtpHeader::V4(_) => self + .efdata + .serialize(w, cipher, ExtensionHeaderVersion::V4)?, + #[cfg(feature = "ntpv5")] + NtpHeader::V5(_) => self + .efdata + .serialize(w, cipher, ExtensionHeaderVersion::V5)?, } if let Some(ref mac) = self.mac { @@ -480,6 +543,29 @@ impl<'a> NtpPacket<'a> { }, mac: None, }, + #[cfg(feature = "ntpv5")] + NtpHeader::V5(header) => NtpPacket { + // TODO deduplicate extension handling with V4 + header: NtpHeader::V5(v5::NtpHeaderV5::timestamp_response( + system, + header, + recv_timestamp, + clock, + )), + efdata: ExtensionFieldData { + authenticated: vec![], + encrypted: vec![], + // Ignore encrypted so as not to accidentaly leak anything + untrusted: input + .efdata + .untrusted + .into_iter() + .chain(input.efdata.authenticated) + .filter(|ef| matches!(ef, ExtensionField::UniqueIdentifier(_))) + .collect(), + }, + mac: None, + }, } } @@ -537,6 +623,8 @@ impl<'a> NtpPacket<'a> { }, mac: None, }, + #[cfg(feature = "ntpv5")] + NtpHeader::V5(_header) => todo!("No NTS support for V5 yet"), } } @@ -563,6 +651,8 @@ impl<'a> NtpPacket<'a> { }, mac: None, }, + #[cfg(feature = "ntpv5")] + NtpHeader::V5(_header) => todo!("NTPv5 does not have KISS codes yet"), } } @@ -583,6 +673,8 @@ impl<'a> NtpPacket<'a> { }, mac: None, }, + #[cfg(feature = "ntpv5")] + NtpHeader::V5(_header) => todo!("No NTS support yet"), } } @@ -609,6 +701,8 @@ impl<'a> NtpPacket<'a> { }, mac: None, }, + #[cfg(feature = "ntpv5")] + NtpHeader::V5(_header) => todo!("NTPv5 does not have KISS codes yet"), } } @@ -629,6 +723,8 @@ impl<'a> NtpPacket<'a> { }, mac: None, }, + #[cfg(feature = "ntpv5")] + NtpHeader::V5(_header) => todo!("No NTS support for NTPv5 yet"), } } @@ -650,6 +746,8 @@ impl<'a> NtpPacket<'a> { }, mac: None, }, + #[cfg(feature = "ntpv5")] + NtpHeader::V5(_header) => todo!("No NTS support for NTPv5 yet"), } } } @@ -666,6 +764,8 @@ impl<'a> NtpPacket<'a> { match self.header { NtpHeader::V3(header) => header.leap, NtpHeader::V4(header) => header.leap, + #[cfg(feature = "ntpv5")] + NtpHeader::V5(header) => header.leap, } } @@ -673,6 +773,13 @@ impl<'a> NtpPacket<'a> { match self.header { NtpHeader::V3(header) => header.mode, NtpHeader::V4(header) => header.mode, + + // FIXME long term the return type should change to capture both mode types + #[cfg(feature = "ntpv5")] + NtpHeader::V5(header) => match header.mode { + v5::NtpMode::Request => NtpAssociationMode::Client, + v5::NtpMode::Response => NtpAssociationMode::Server, + }, } } @@ -680,6 +787,8 @@ impl<'a> NtpPacket<'a> { match self.header { NtpHeader::V3(header) => header.stratum, NtpHeader::V4(header) => header.stratum, + #[cfg(feature = "ntpv5")] + NtpHeader::V5(header) => header.stratum, } } @@ -687,6 +796,8 @@ impl<'a> NtpPacket<'a> { match self.header { NtpHeader::V3(header) => header.precision, NtpHeader::V4(header) => header.precision, + #[cfg(feature = "ntpv5")] + NtpHeader::V5(header) => header.precision, } } @@ -694,6 +805,8 @@ impl<'a> NtpPacket<'a> { match self.header { NtpHeader::V3(header) => header.root_delay, NtpHeader::V4(header) => header.root_delay, + #[cfg(feature = "ntpv5")] + NtpHeader::V5(header) => header.root_delay, } } @@ -701,6 +814,8 @@ impl<'a> NtpPacket<'a> { match self.header { NtpHeader::V3(header) => header.root_dispersion, NtpHeader::V4(header) => header.root_dispersion, + #[cfg(feature = "ntpv5")] + NtpHeader::V5(header) => header.root_dispersion, } } @@ -708,6 +823,8 @@ impl<'a> NtpPacket<'a> { match self.header { NtpHeader::V3(header) => header.receive_timestamp, NtpHeader::V4(header) => header.receive_timestamp, + #[cfg(feature = "ntpv5")] + NtpHeader::V5(header) => header.receive_timestamp, } } @@ -715,6 +832,8 @@ impl<'a> NtpPacket<'a> { match self.header { NtpHeader::V3(header) => header.transmit_timestamp, NtpHeader::V4(header) => header.transmit_timestamp, + #[cfg(feature = "ntpv5")] + NtpHeader::V5(header) => header.transmit_timestamp, } } @@ -722,6 +841,8 @@ impl<'a> NtpPacket<'a> { match self.header { NtpHeader::V3(header) => header.reference_id, NtpHeader::V4(header) => header.reference_id, + #[cfg(feature = "ntpv5")] + NtpHeader::V5(_header) => todo!("NTPv5 does not have reference IDs"), } } @@ -729,6 +850,8 @@ impl<'a> NtpPacket<'a> { match self.header { NtpHeader::V3(header) => header.stratum == 0, NtpHeader::V4(header) => header.stratum == 0, + #[cfg(feature = "ntpv5")] + NtpHeader::V5(_header) => todo!("NTPv5 does not have kiss codes yet"), } } @@ -774,6 +897,11 @@ impl<'a> NtpPacket<'a> { NtpHeader::V4(header) => { header.origin_timestamp == identifier.expected_origin_timestamp } + #[cfg(feature = "ntpv5")] + NtpHeader::V5(header) => { + header.client_cookie + == v5::NtpClientCookie::from_ntp_timestamp(identifier.expected_origin_timestamp) + } } } } @@ -810,6 +938,14 @@ impl<'a> NtpPacket<'a> { match &mut self.header { NtpHeader::V3(ref mut header) => header.mode = mode, NtpHeader::V4(ref mut header) => header.mode = mode, + #[cfg(feature = "ntpv5")] + NtpHeader::V5(ref mut header) => { + header.mode = match mode { + NtpAssociationMode::Client => v5::NtpMode::Request, + NtpAssociationMode::Server => v5::NtpMode::Response, + _ => todo!("NTPv5 can only handle client-server"), + } + } } } @@ -817,6 +953,11 @@ impl<'a> NtpPacket<'a> { match &mut self.header { NtpHeader::V3(ref mut header) => header.origin_timestamp = timestamp, NtpHeader::V4(ref mut header) => header.origin_timestamp = timestamp, + #[cfg(feature = "ntpv5")] + // TODO can we just reuse the cookie as the origin timestamp? + NtpHeader::V5(ref mut header) => { + header.client_cookie = v5::NtpClientCookie::from_ntp_timestamp(timestamp) + } } } @@ -824,6 +965,8 @@ impl<'a> NtpPacket<'a> { match &mut self.header { NtpHeader::V3(ref mut header) => header.transmit_timestamp = timestamp, NtpHeader::V4(ref mut header) => header.transmit_timestamp = timestamp, + #[cfg(feature = "ntpv5")] + NtpHeader::V5(ref mut header) => header.transmit_timestamp = timestamp, } } @@ -831,6 +974,8 @@ impl<'a> NtpPacket<'a> { match &mut self.header { NtpHeader::V3(ref mut header) => header.receive_timestamp = timestamp, NtpHeader::V4(ref mut header) => header.receive_timestamp = timestamp, + #[cfg(feature = "ntpv5")] + NtpHeader::V5(ref mut header) => header.receive_timestamp = timestamp, } } @@ -838,6 +983,8 @@ impl<'a> NtpPacket<'a> { match &mut self.header { NtpHeader::V3(ref mut header) => header.precision = precision, NtpHeader::V4(ref mut header) => header.precision = precision, + #[cfg(feature = "ntpv5")] + NtpHeader::V5(ref mut header) => header.precision = precision, } } @@ -845,6 +992,8 @@ impl<'a> NtpPacket<'a> { match &mut self.header { NtpHeader::V3(ref mut header) => header.leap = leap, NtpHeader::V4(ref mut header) => header.leap = leap, + #[cfg(feature = "ntpv5")] + NtpHeader::V5(ref mut header) => header.leap = leap, } } @@ -852,6 +1001,8 @@ impl<'a> NtpPacket<'a> { match &mut self.header { NtpHeader::V3(ref mut header) => header.stratum = stratum, NtpHeader::V4(ref mut header) => header.stratum = stratum, + #[cfg(feature = "ntpv5")] + NtpHeader::V5(ref mut header) => header.stratum = stratum, } } @@ -859,6 +1010,8 @@ impl<'a> NtpPacket<'a> { match &mut self.header { NtpHeader::V3(ref mut header) => header.reference_id = reference_id, NtpHeader::V4(ref mut header) => header.reference_id = reference_id, + #[cfg(feature = "ntpv5")] + NtpHeader::V5(_header) => todo!("NTPv5 does not have reference IDs"), } } @@ -866,6 +1019,8 @@ impl<'a> NtpPacket<'a> { match &mut self.header { NtpHeader::V3(ref mut header) => header.root_delay = root_delay, NtpHeader::V4(ref mut header) => header.root_delay = root_delay, + #[cfg(feature = "ntpv5")] + NtpHeader::V5(ref mut header) => header.root_delay = root_delay, } } @@ -873,6 +1028,8 @@ impl<'a> NtpPacket<'a> { match &mut self.header { NtpHeader::V3(ref mut header) => header.root_dispersion = root_dispersion, NtpHeader::V4(ref mut header) => header.root_dispersion = root_dispersion, + #[cfg(feature = "ntpv5")] + NtpHeader::V5(ref mut header) => header.root_dispersion = root_dispersion, } } } @@ -1070,12 +1227,17 @@ mod tests { assert!(NtpPacket::deserialize(packet, &NoCipher).is_err()); let packet = b"\x14\x02\x06\xe9\x00\x00\x02\x36\x00\x00\x03\xb7\xc0\x35\x67\x6c\xe5\xf6\x61\xfd\x6f\x16\x5f\x03\xe5\xf6\x63\xa8\x76\x19\xef\x40\xe5\xf6\x63\xa8\x79\x8c\x65\x81\xe5\xf6\x63\xa8\x79\x8e\xae\x2b"; assert!(NtpPacket::deserialize(packet, &NoCipher).is_err()); - let packet = b"\x2B\x02\x06\xe9\x00\x00\x02\x36\x00\x00\x03\xb7\xc0\x35\x67\x6c\xe5\xf6\x61\xfd\x6f\x16\x5f\x03\xe5\xf6\x63\xa8\x76\x19\xef\x40\xe5\xf6\x63\xa8\x79\x8c\x65\x81\xe5\xf6\x63\xa8\x79\x8e\xae\x2b"; - assert!(NtpPacket::deserialize(packet, &NoCipher).is_err()); let packet = b"\x34\x02\x06\xe9\x00\x00\x02\x36\x00\x00\x03\xb7\xc0\x35\x67\x6c\xe5\xf6\x61\xfd\x6f\x16\x5f\x03\xe5\xf6\x63\xa8\x76\x19\xef\x40\xe5\xf6\x63\xa8\x79\x8c\x65\x81\xe5\xf6\x63\xa8\x79\x8e\xae\x2b"; assert!(NtpPacket::deserialize(packet, &NoCipher).is_err()); let packet = b"\x3B\x02\x06\xe9\x00\x00\x02\x36\x00\x00\x03\xb7\xc0\x35\x67\x6c\xe5\xf6\x61\xfd\x6f\x16\x5f\x03\xe5\xf6\x63\xa8\x76\x19\xef\x40\xe5\xf6\x63\xa8\x79\x8c\x65\x81\xe5\xf6\x63\xa8\x79\x8e\xae\x2b"; assert!(NtpPacket::deserialize(packet, &NoCipher).is_err()); + + #[cfg(not(feature = "ntpv5"))] + { + // Version 5 packet should not parse without the ntpv5 feature + let packet = b"\x2C\x02\x06\xe9\x00\x00\x02\x36\x00\x00\x03\xb7\xc0\x35\x67\x6c\xe5\xf6\x61\xfd\x6f\x16\x5f\x03\xe5\xf6\x63\xa8\x76\x19\xef\x40\xe5\xf6\x63\xa8\x79\x8c\x65\x81\xe5\xf6\x63\xa8\x79\x8e\xae\x2b"; + assert!(NtpPacket::deserialize(packet, &NoCipher).is_err()); + } } #[test] diff --git a/ntp-proto/src/packet/v5/error.rs b/ntp-proto/src/packet/v5/error.rs new file mode 100644 index 000000000..03d7d5100 --- /dev/null +++ b/ntp-proto/src/packet/v5/error.rs @@ -0,0 +1,34 @@ +use crate::packet::error::ParsingError; +use std::fmt::{Display, Formatter}; + +#[derive(Debug)] +pub enum V5Error { + InvalidDraftIdentification, + MalformedTimescale, + MalformedMode, + InvalidFlags, +} + +impl V5Error { + /// `const` alternative to `.into()` + pub const fn into_parse_err(self) -> ParsingError { + ParsingError::V5(self) + } +} + +impl Display for V5Error { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::InvalidDraftIdentification => f.write_str("Draft Identification invalid"), + Self::MalformedTimescale => f.write_str("Malformed timescale"), + Self::MalformedMode => f.write_str("Malformed mode"), + Self::InvalidFlags => f.write_str("Invalid flags specified"), + } + } +} + +impl From for crate::packet::error::ParsingError { + fn from(value: V5Error) -> Self { + Self::V5(value) + } +} diff --git a/ntp-proto/src/packet/v5/extension_fields.rs b/ntp-proto/src/packet/v5/extension_fields.rs new file mode 100644 index 000000000..45cf8d6f2 --- /dev/null +++ b/ntp-proto/src/packet/v5/extension_fields.rs @@ -0,0 +1,87 @@ +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub enum Type { + DraftIdentification, + Padding, + Mac, + ReferenceIdRequest, + ReferenceIdResponse, + ServerInformation, + Correction, + ReferenceTimestamp, + MonotonicReceiveTimestamp, + SecondaryReceiveTimestamp, + Unknown(u16), +} + +impl Type { + pub const fn from_bits(bits: u16) -> Self { + match bits { + 0xF5FF => Self::DraftIdentification, + 0xF501 => Self::Padding, + 0xF502 => Self::Mac, + 0xF503 => Self::ReferenceIdRequest, + 0xF504 => Self::ReferenceIdResponse, + 0xF505 => Self::ServerInformation, + 0xF506 => Self::Correction, + 0xF507 => Self::ReferenceTimestamp, + 0xF508 => Self::MonotonicReceiveTimestamp, + 0xF509 => Self::SecondaryReceiveTimestamp, + other => Self::Unknown(other), + } + } + + pub const fn to_bits(self) -> u16 { + match self { + Self::DraftIdentification => 0xF5FF, + Self::Padding => 0xF501, + Self::Mac => 0xF502, + Self::ReferenceIdRequest => 0xF503, + Self::ReferenceIdResponse => 0xF504, + Self::ServerInformation => 0xF505, + Self::Correction => 0xF506, + Self::ReferenceTimestamp => 0xF507, + Self::MonotonicReceiveTimestamp => 0xF508, + Self::SecondaryReceiveTimestamp => 0xF509, + Self::Unknown(other) => other, + } + } + + #[cfg(test)] + fn all_known() -> impl Iterator { + [ + Self::DraftIdentification, + Self::Padding, + Self::Mac, + Self::ReferenceIdRequest, + Self::ReferenceIdResponse, + Self::ServerInformation, + Self::Correction, + Self::ReferenceTimestamp, + Self::MonotonicReceiveTimestamp, + Self::SecondaryReceiveTimestamp, + ] + .into_iter() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn type_round_trip() { + for i in 0..=u16::MAX { + let ty = Type::from_bits(i); + assert_eq!(i, ty.to_bits()); + } + + for ty in Type::all_known() { + let bits = ty.to_bits(); + let ty2 = Type::from_bits(bits); + assert_eq!(ty, ty2); + + let bits2 = ty2.to_bits(); + assert_eq!(bits, bits2); + } + } +} diff --git a/ntp-proto/src/packet/v5/mod.rs b/ntp-proto/src/packet/v5/mod.rs new file mode 100644 index 000000000..c596fd4e9 --- /dev/null +++ b/ntp-proto/src/packet/v5/mod.rs @@ -0,0 +1,469 @@ +#![warn(clippy::missing_const_for_fn)] +use crate::{NtpClock, NtpDuration, NtpLeapIndicator, NtpTimestamp, SystemSnapshot}; +use rand::random; + +mod error; +#[allow(dead_code)] +pub mod extension_fields; + +use crate::packet::error::ParsingError; +pub use error::V5Error; + +#[allow(dead_code)] +pub(crate) const DRAFT_VERSION: &str = "draft-ietf-ntp-ntpv5-00"; + +#[repr(u8)] +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +pub enum NtpMode { + Request = 3, + Response = 4, +} + +impl NtpMode { + const fn from_bits(bits: u8) -> Result> { + Ok(match bits { + 3 => Self::Request, + 4 => Self::Response, + _ => return Err(V5Error::MalformedMode.into_parse_err()), + }) + } + + const fn to_bits(self) -> u8 { + self as u8 + } + + #[allow(dead_code)] + pub(crate) const fn is_request(self) -> bool { + matches!(self, Self::Request) + } + + #[allow(dead_code)] + pub(crate) const fn is_response(self) -> bool { + matches!(self, Self::Response) + } +} + +#[repr(u8)] +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +pub enum NtpTimescale { + Utc = 0, + Tai = 1, + Ut1 = 2, + LeapSmearedUtc = 3, +} + +impl NtpTimescale { + const fn from_bits(bits: u8) -> Result> { + Ok(match bits { + 0 => Self::Utc, + 1 => Self::Tai, + 2 => Self::Ut1, + 3 => Self::LeapSmearedUtc, + _ => return Err(V5Error::MalformedTimescale.into_parse_err()), + }) + } + + const fn to_bits(self) -> u8 { + self as u8 + } +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub struct NtpEra(pub u8); + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub struct NtpFlags { + unknown_leap: bool, + interleaved_mode: bool, +} + +impl NtpFlags { + const fn from_bits(bits: [u8; 2]) -> Result> { + if bits[0] != 0x00 || bits[1] & 0b1111_1100 != 0 { + return Err(V5Error::InvalidFlags.into_parse_err()); + } + + Ok(Self { + unknown_leap: bits[1] & 0b01 != 0, + interleaved_mode: bits[1] & 0b10 != 0, + }) + } + + const fn as_bits(self) -> [u8; 2] { + let mut flags: u8 = 0; + + if self.unknown_leap { + flags |= 0b01; + } + + if self.interleaved_mode { + flags |= 0b10; + } + + [0x00, flags] + } +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub struct NtpServerCookie(pub [u8; 8]); + +impl NtpServerCookie { + fn new_random() -> NtpServerCookie { + // TODO does this match entropy handling of the rest of the system? + Self(random()) + } +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub struct NtpClientCookie(pub [u8; 8]); + +impl NtpClientCookie { + pub const fn from_ntp_timestamp(ts: NtpTimestamp) -> Self { + Self(ts.to_bits()) + } +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub struct NtpHeaderV5 { + pub leap: NtpLeapIndicator, + pub mode: NtpMode, + pub stratum: u8, + pub poll: i8, + pub precision: i8, + pub timescale: NtpTimescale, + pub era: NtpEra, + pub flags: NtpFlags, + pub root_delay: NtpDuration, + pub root_dispersion: NtpDuration, + pub server_cookie: NtpServerCookie, + pub client_cookie: NtpClientCookie, + /// Time at the server when the request arrived from the client + pub receive_timestamp: NtpTimestamp, + /// Time at the server when the response left for the client + pub transmit_timestamp: NtpTimestamp, +} + +impl NtpHeaderV5 { + pub(crate) fn timestamp_response( + system: &SystemSnapshot, + input: Self, + recv_timestamp: NtpTimestamp, + clock: &C, + ) -> Self { + Self { + leap: system.time_snapshot.leap_indicator, + mode: NtpMode::Response, + stratum: system.stratum, + // TODO this changed in NTPv5 + poll: input.poll, + precision: system.time_snapshot.precision.log2(), + // TODO this is new in NTPv5 + timescale: NtpTimescale::Utc, + // TODO this is new in NTPv5 + era: NtpEra(0), + // TODO this is new in NTPv5 + flags: NtpFlags { + unknown_leap: false, + interleaved_mode: false, + }, + root_delay: system.time_snapshot.root_delay, + root_dispersion: system.time_snapshot.root_dispersion, + server_cookie: NtpServerCookie::new_random(), + client_cookie: input.client_cookie, + receive_timestamp: recv_timestamp, + transmit_timestamp: clock.now().expect("Failed to read time"), + } + } +} + +impl NtpHeaderV5 { + const WIRE_LENGTH: usize = 48; + const VERSION: u8 = 5; + + pub(crate) fn deserialize( + data: &[u8], + ) -> Result<(Self, usize), ParsingError> { + if data.len() < Self::WIRE_LENGTH { + return Err(ParsingError::IncorrectLength); + } + + let version = (data[0] >> 3) & 0b111; + if version != 5 { + return Err(ParsingError::InvalidVersion(version)); + } + + Ok(( + Self { + leap: NtpLeapIndicator::from_bits((data[0] & 0xC0) >> 6), + mode: NtpMode::from_bits(data[0] & 0x07)?, + stratum: data[1], + poll: data[2] as i8, + precision: data[3] as i8, + timescale: NtpTimescale::from_bits(data[4])?, + era: NtpEra(data[5]), + flags: NtpFlags::from_bits(data[6..8].try_into().unwrap())?, + root_delay: NtpDuration::from_bits_short(data[8..12].try_into().unwrap()), + root_dispersion: NtpDuration::from_bits_short(data[12..16].try_into().unwrap()), + server_cookie: NtpServerCookie(data[16..24].try_into().unwrap()), + client_cookie: NtpClientCookie(data[24..32].try_into().unwrap()), + receive_timestamp: NtpTimestamp::from_bits(data[32..40].try_into().unwrap()), + transmit_timestamp: NtpTimestamp::from_bits(data[40..48].try_into().unwrap()), + }, + Self::WIRE_LENGTH, + )) + } + + #[allow(dead_code)] + pub(crate) fn serialize(&self, w: &mut W) -> std::io::Result<()> { + w.write_all(&[(self.leap.to_bits() << 6) | (Self::VERSION << 3) | self.mode.to_bits()])?; + w.write_all(&[self.stratum, self.poll as u8, self.precision as u8])?; + w.write_all(&[self.timescale.to_bits()])?; + w.write_all(&[self.era.0])?; + w.write_all(&self.flags.as_bits())?; + w.write_all(&self.root_delay.to_bits_short())?; + w.write_all(&self.root_dispersion.to_bits_short())?; + w.write_all(&self.server_cookie.0)?; + w.write_all(&self.client_cookie.0)?; + w.write_all(&self.receive_timestamp.to_bits())?; + w.write_all(&self.transmit_timestamp.to_bits())?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::io::Cursor; + + #[test] + fn round_trip_timescale() { + for i in 0..=u8::MAX { + if let Ok(ts) = NtpTimescale::from_bits(i) { + assert_eq!(ts as u8, i); + } + } + } + + #[test] + fn flags() { + let flags = NtpFlags::from_bits([0x00, 0x00]).unwrap(); + assert!(!flags.unknown_leap); + assert!(!flags.interleaved_mode); + + let flags = NtpFlags::from_bits([0x00, 0x01]).unwrap(); + assert!(flags.unknown_leap); + assert!(!flags.interleaved_mode); + + let flags = NtpFlags::from_bits([0x00, 0x02]).unwrap(); + assert!(!flags.unknown_leap); + assert!(flags.interleaved_mode); + + let flags = NtpFlags::from_bits([0x00, 0x03]).unwrap(); + assert!(flags.unknown_leap); + assert!(flags.interleaved_mode); + + let result = NtpFlags::from_bits([0xFF, 0xFF]); + assert!(matches!( + result, + Err(ParsingError::V5(V5Error::InvalidFlags)) + )); + } + + #[test] + fn parse_request() { + #[allow(clippy::unusual_byte_groupings)] // Bits are grouped by fields + #[rustfmt::skip] + let data = [ + // LI VN Mode + 0b_00_101_011, + // Stratum + 0x00, + // Poll + 0x05, + // Precision + 0x00, + // Timescale (0: UTC, 1: TAI, 2: UT1, 3: Leap-smeared UTC) + 0x02, + // Era + 0x00, + // Flags + 0x00, + 0b0000_00_1_0, + // Root Delay + 0x00, 0x00, 0x00, 0x00, + // Root Dispersion + 0x00, 0x00, 0x00, 0x00, + // Server Cookie + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // Client Cookie + 0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01, + // Receive Timestamp + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // Transmit Timestamp + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + ]; + + let (parsed, len) = NtpHeaderV5::deserialize(&data).unwrap(); + + assert_eq!(len, 48); + assert_eq!(parsed.leap, NtpLeapIndicator::NoWarning); + assert!(parsed.mode.is_request()); + assert_eq!(parsed.stratum, 0); + assert_eq!(parsed.poll, 5); + assert_eq!(parsed.precision, 0); + assert_eq!(parsed.timescale, NtpTimescale::Ut1); + assert_eq!(parsed.era, NtpEra(0)); + assert!(parsed.flags.interleaved_mode); + assert!(!parsed.flags.unknown_leap); + assert!(parsed.flags.interleaved_mode); + assert_eq!(parsed.root_delay, NtpDuration::from_seconds(0.0)); + assert_eq!(parsed.root_dispersion, NtpDuration::from_seconds(0.0)); + assert_eq!(parsed.server_cookie, NtpServerCookie([0x0; 8])); + assert_eq!( + parsed.client_cookie, + NtpClientCookie([0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01]) + ); + assert_eq!(parsed.receive_timestamp, NtpTimestamp::from_fixed_int(0x0)); + assert_eq!(parsed.transmit_timestamp, NtpTimestamp::from_fixed_int(0x0)); + + let mut buffer: [u8; 48] = [0u8; 48]; + let mut cursor = Cursor::new(buffer.as_mut_slice()); + parsed.serialize(&mut cursor).unwrap(); + + assert_eq!(data, buffer); + } + + #[test] + fn parse_resonse() { + #[allow(clippy::unusual_byte_groupings)] // Bits are grouped by fields + #[rustfmt::skip] + let data = [ + // LI VN Mode + 0b_00_101_100, + // Stratum + 0x04, + // Poll + 0x05, + // Precision + 0x06, + // Timescale (0: UTC, 1: TAI, 2: UT1, 3: Leap-smeared UTC) + 0x01, + // Era + 0x07, + // Flags + 0x00, + 0b0000_00_1_0, + // Root Delay + 0x00, 0x00, 0x02, 0x3f, + // Root Dispersion + 0x00, 0x00, 0x00, 0x42, + // Server Cookie + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + // Client Cookie + 0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01, + // Receive Timestamp + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + // Transmit Timestamp + 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, + ]; + + let (parsed, len) = NtpHeaderV5::deserialize(&data).unwrap(); + + assert_eq!(len, 48); + assert_eq!(parsed.leap, NtpLeapIndicator::NoWarning); + assert!(parsed.mode.is_response()); + assert_eq!(parsed.stratum, 4); + assert_eq!(parsed.poll, 5); + assert_eq!(parsed.precision, 6); + assert_eq!(parsed.timescale, NtpTimescale::Tai); + assert_eq!(parsed.era, NtpEra(7)); + assert!(parsed.flags.interleaved_mode); + assert!(!parsed.flags.unknown_leap); + assert!(parsed.flags.interleaved_mode); + assert_eq!( + parsed.root_delay, + NtpDuration::from_seconds(0.00877380371298031) + ); + assert_eq!( + parsed.root_dispersion, + NtpDuration::from_seconds(0.001007080078359479) + ); + assert_eq!( + parsed.server_cookie, + NtpServerCookie([0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]) + ); + assert_eq!( + parsed.client_cookie, + NtpClientCookie([0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01]) + ); + assert_eq!( + parsed.receive_timestamp, + NtpTimestamp::from_fixed_int(0x1111111111111111) + ); + assert_eq!( + parsed.transmit_timestamp, + NtpTimestamp::from_fixed_int(0x2222222222222222) + ); + + let mut buffer: [u8; 48] = [0u8; 48]; + let mut cursor = Cursor::new(buffer.as_mut_slice()); + parsed.serialize(&mut cursor).unwrap(); + + assert_eq!(data, buffer); + } + + #[test] + fn test_encode_decode_roundtrip() { + for i in 0..=u8::MAX { + let header = NtpHeaderV5 { + leap: NtpLeapIndicator::from_bits(i % 4), + mode: NtpMode::from_bits(3 + (i % 2)).unwrap(), + stratum: i.wrapping_add(1), + poll: i.wrapping_add(3) as i8, + precision: i.wrapping_add(4) as i8, + timescale: NtpTimescale::from_bits(i % 4).unwrap(), + era: NtpEra(i.wrapping_add(6)), + flags: NtpFlags { + unknown_leap: i % 3 == 0, + interleaved_mode: i % 4 == 0, + }, + root_delay: NtpDuration::from_bits_short([i; 4]), + root_dispersion: NtpDuration::from_bits_short([i.wrapping_add(1); 4]), + server_cookie: NtpServerCookie([i.wrapping_add(2); 8]), + client_cookie: NtpClientCookie([i.wrapping_add(3); 8]), + receive_timestamp: NtpTimestamp::from_bits([i.wrapping_add(4); 8]), + transmit_timestamp: NtpTimestamp::from_bits([i.wrapping_add(5); 8]), + }; + + let mut buffer: [u8; 48] = [0u8; 48]; + let mut cursor = Cursor::new(buffer.as_mut_slice()); + header.serialize(&mut cursor).unwrap(); + + let (parsed, _) = NtpHeaderV5::deserialize(&buffer).unwrap(); + + assert_eq!(header, parsed); + } + } + + #[test] + fn fail_on_incorrect_length() { + let data: [u8; 47] = [0u8; 47]; + + assert!(matches!( + NtpHeaderV5::deserialize(&data), + Err(ParsingError::IncorrectLength) + )); + } + + #[test] + #[allow(clippy::unusual_byte_groupings)] // Bits are grouped by fields + fn fail_on_incorrect_version() { + let mut data: [u8; 48] = [0u8; 48]; + data[0] = 0b_00_111_100; + + assert!(matches!( + NtpHeaderV5::deserialize(&data), + Err(ParsingError::InvalidVersion(7)) + )); + } +}