diff --git a/include/quicr/detail/ctrl_message_types.h b/include/quicr/detail/ctrl_message_types.h index 2fdc9db9..45b938fb 100644 --- a/include/quicr/detail/ctrl_message_types.h +++ b/include/quicr/detail/ctrl_message_types.h @@ -408,8 +408,9 @@ namespace quicr::messages { { if constexpr (std::is_arithmetic_v || std::is_enum_v) { if (static_cast(type) % 2 == 0) { - UintVar u_value(static_cast(value)); - return Bytes{ u_value.begin(), u_value.end() }; + const std::uint64_t val = static_cast(value); + auto* val_bytes = reinterpret_cast(&val); + return Bytes{ val_bytes, val_bytes + sizeof(val) }; } } @@ -459,7 +460,10 @@ namespace quicr::messages { { if constexpr (std::is_arithmetic_v) { if (static_cast(type) % 2 == 0) { - return static_cast(UintVar(extensions.at(static_cast(type))).Get()); + std::uint64_t val = 0; + const auto& bytes = extensions.at(static_cast(type)); + std::memcpy(&val, bytes.data(), std::min(bytes.size(), sizeof(val))); + return static_cast(val); } } @@ -477,7 +481,10 @@ namespace quicr::messages { { if constexpr (std::is_arithmetic_v) { if (static_cast(type) % 2 == 0) { - return static_cast(UintVar(immutable_extensions.at(static_cast(type))).Get()); + std::uint64_t val = 0; + const auto& bytes = immutable_extensions.at(static_cast(type)); + std::memcpy(&val, bytes.data(), std::min(bytes.size(), sizeof(val))); + return static_cast(val); } } @@ -523,8 +530,9 @@ namespace quicr::messages { { if constexpr (std::is_arithmetic_v || std::is_enum_v) { if (static_cast(type) % 2 == 0) { - UintVar u_value(static_cast(value)); - parameters.push_back({ type, Bytes{ u_value.begin(), u_value.end() } }); + const std::uint64_t val = static_cast(value); + auto* val_bytes = reinterpret_cast(&val); + parameters.push_back({ type, Bytes{ val_bytes, val_bytes + sizeof(val) } }); return *this; } } @@ -584,7 +592,9 @@ namespace quicr::messages { if constexpr (std::is_arithmetic_v) { if (static_cast(type) % 2 == 0) { - return static_cast(UintVar(bytes).Get()); + std::uint64_t val = 0; + std::memcpy(&val, bytes.data(), std::min(bytes.size(), sizeof(val))); + return static_cast(val); } } diff --git a/test/moq_ctrl_messages.cpp b/test/moq_ctrl_messages.cpp index 8c019870..37610b0f 100644 --- a/test/moq_ctrl_messages.cpp +++ b/test/moq_ctrl_messages.cpp @@ -810,3 +810,79 @@ TEST_CASE("uint16_t encode/decode") { IntegerEncodeDecode(true); } + +TEST_CASE("KeyValuePair even-type round-trip preserves values") +{ + const std::vector test_values = { + 0, 1, + 63, // Max 1-byte varint + 64, // Min 2-byte varint + 127, 128, 255, + 16383, // Max 2-byte varint + 16384, // Min 4-byte varint + 100000, + }; + + for (const auto value : test_values) { + CAPTURE(value); + + Parameters params; + params.Add(ParameterType::kDeliveryTimeout, value); + + Bytes buffer; + buffer << params; + + // Should have encoded as uintvar. + UintVar expected(value); + Bytes expected_bytes{ expected.begin(), expected.end() }; + REQUIRE(buffer.size() >= expected_bytes.size()); + Bytes tail(buffer.end() - expected_bytes.size(), buffer.end()); + CHECK_EQ(tail, expected_bytes); + + Parameters out; + BytesSpan span{ buffer }; + span >> out; + + // Roundtrip. + CHECK_NOTHROW(out.Get(ParameterType::kDeliveryTimeout)); + CHECK_EQ(out.Get(ParameterType::kDeliveryTimeout), value); + } +} + +TEST_CASE("TrackExtensions even-type round-trip preserves values") +{ + const std::vector test_values = { + 0, 1, + 63, // Max 1-byte varint + 64, // Min 2-byte varint + 127, 128, 255, + 16383, // Max 2-byte varint + 16384, // Min 4-byte varint + 100000, + }; + + for (const auto value : test_values) { + CAPTURE(value); + + TrackExtensions ext; + ext.Add(ExtensionType::kDeliveryTimeout, value); + + Bytes buffer; + buffer << ext; + + // Should have been encoded as uintvar. + UintVar expected(value); + Bytes expected_bytes{ expected.begin(), expected.end() }; + REQUIRE(buffer.size() >= expected_bytes.size()); + Bytes tail(buffer.end() - expected_bytes.size(), buffer.end()); + CHECK_EQ(tail, expected_bytes); + + TrackExtensions out; + BytesSpan span{ buffer }; + span >> out; + + // Roundtrip. + CHECK_NOTHROW(out.Get(ExtensionType::kDeliveryTimeout)); + CHECK_EQ(out.Get(ExtensionType::kDeliveryTimeout), value); + } +}