Skip to content

Commit de2ee3d

Browse files
committed
Randomize peer port.
1 parent cfc75ae commit de2ee3d

File tree

1 file changed

+64
-57
lines changed

1 file changed

+64
-57
lines changed

ntpd/src/daemon/peer.rs

Lines changed: 64 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,11 @@ pub(crate) struct PeerTask<C: 'static + NtpClock + Send, T: Wait> {
5757
_wait: PhantomData<T>,
5858
index: PeerId,
5959
clock: C,
60-
socket: Socket<SocketAddr, Connected>,
60+
interface: Option<InterfaceName>,
61+
timestamp_mode: TimestampMode,
62+
source_addr: SocketAddr,
63+
network_wait_period: std::time::Duration,
64+
socket: Option<Socket<SocketAddr, Connected>>,
6165
channels: PeerChannels,
6266

6367
peer: Peer,
@@ -86,6 +90,12 @@ enum PacketResult {
8690
Demobilize,
8791
}
8892

93+
#[derive(Debug)]
94+
enum SocketResult {
95+
Ok,
96+
Abort,
97+
}
98+
8999
impl<C, T> PeerTask<C, T>
90100
where
91101
C: 'static + NtpClock + Send + Sync,
@@ -160,7 +170,11 @@ where
160170
}
161171
}
162172

163-
match self.socket.send(packet).await {
173+
if matches!(self.setup_socket().await, SocketResult::Abort) {
174+
return PollResult::NetworkGone;
175+
}
176+
177+
match self.socket.as_mut().unwrap().send(packet).await {
164178
Err(error) => {
165179
warn!(?error, "poll message could not be sent");
166180

@@ -217,6 +231,8 @@ where
217231
}
218232
};
219233
self.channels.msg_for_system_sender.send(msg).await.ok();
234+
// No longer needed since we don't expect any more packets
235+
self.socket = None;
220236
}
221237
Err(IgnoreReason::KissDemobilize) => {
222238
info!("Demobilizing peer connection on request of remote.");
@@ -233,6 +249,32 @@ where
233249
PacketResult::Ok
234250
}
235251

252+
async fn setup_socket(&mut self) -> SocketResult {
253+
let socket_res = match self.interface {
254+
#[cfg(target_os = "linux")]
255+
Some(interface) => {
256+
open_interface_udp(
257+
interface,
258+
0, /*lets os choose*/
259+
self.timestamp_mode.as_interface_mode(),
260+
)
261+
.and_then(|socket| socket.connect(self.source_addr))
262+
}
263+
_ => connect_address(self.source_addr, self.timestamp_mode.as_general_mode()),
264+
};
265+
266+
self.socket = match socket_res {
267+
Ok(socket) => Some(socket),
268+
Err(error) => {
269+
warn!(?error, "Could not open socket");
270+
tokio::time::sleep(self.network_wait_period).await;
271+
return SocketResult::Abort;
272+
}
273+
};
274+
275+
SocketResult::Ok
276+
}
277+
236278
async fn run(&mut self, mut poll_wait: Pin<&mut T>) {
237279
loop {
238280
let mut buf = [0_u8; 1024];
@@ -252,7 +294,7 @@ where
252294
}
253295
}
254296
},
255-
result = self.socket.recv(&mut buf) => {
297+
result = async { if let Some(ref mut socket) = self.socket { socket.recv(&mut buf).await } else { std::future::pending().await }} => {
256298
tracing::debug!("accept packet");
257299
match accept_packet(result, &buf, &self.clock) {
258300
AcceptResult::Accept(packet, recv_timestamp) => {
@@ -292,7 +334,7 @@ where
292334
#[instrument(skip(clock, channels))]
293335
pub fn spawn(
294336
index: PeerId,
295-
addr: SocketAddr,
337+
source_addr: SocketAddr,
296338
interface: Option<InterfaceName>,
297339
clock: C,
298340
timestamp_mode: TimestampMode,
@@ -303,36 +345,6 @@ where
303345
) -> tokio::task::JoinHandle<()> {
304346
tokio::spawn(
305347
(async move {
306-
let socket_res = match interface {
307-
#[cfg(target_os = "linux")]
308-
Some(interface) => {
309-
open_interface_udp(
310-
interface,
311-
0, /*lets os choose*/
312-
timestamp_mode.as_interface_mode(),
313-
)
314-
.and_then(|socket| socket.connect(addr))
315-
}
316-
_ => connect_address(addr, timestamp_mode.as_general_mode()),
317-
};
318-
319-
let socket = match socket_res {
320-
Ok(socket) => socket,
321-
Err(error) => {
322-
warn!(?error, "Could not open socket");
323-
tokio::time::sleep(network_wait_period).await;
324-
channels
325-
.msg_for_system_sender
326-
.send(MsgForSystem::NetworkIssue(index))
327-
.await
328-
.ok();
329-
return;
330-
}
331-
};
332-
333-
// Unwrap should be safe because we know the socket was connected to a remote peer just before
334-
let source_addr = socket.peer_addr().unwrap();
335-
336348
let local_clock_time = NtpInstant::now();
337349
let config_snapshot = *channels.source_defaults_config_receiver.borrow_and_update();
338350
let peer = if let Some(nts) = nts {
@@ -360,7 +372,11 @@ where
360372
index,
361373
clock,
362374
channels,
363-
socket,
375+
interface,
376+
timestamp_mode,
377+
network_wait_period,
378+
source_addr,
379+
socket: None,
364380
peer,
365381
last_send_timestamp: None,
366382
last_poll_sent: Instant::now(),
@@ -431,7 +447,7 @@ mod tests {
431447
use std::{io::Cursor, net::Ipv4Addr, sync::Arc, time::Duration};
432448

433449
use ntp_proto::{NoCipher, NtpDuration, NtpLeapIndicator, NtpPacket, TimeSnapshot};
434-
use timestamped_socket::socket::{open_ip, GeneralTimestampMode};
450+
use timestamped_socket::socket::{open_ip, GeneralTimestampMode, Open};
435451
use tokio::sync::mpsc;
436452

437453
use crate::daemon::util::EPOCH_OFFSET;
@@ -552,29 +568,16 @@ mod tests {
552568
port_base: u16,
553569
) -> (
554570
PeerTask<TestClock, T>,
555-
Socket<SocketAddr, Connected>,
571+
Socket<SocketAddr, Open>,
556572
mpsc::Receiver<MsgForSystem>,
557573
) {
558574
// Note: Ports must be unique among tests to deal with parallelism, hence
559575
// port_base
560-
let socket = open_ip(
561-
SocketAddr::from((Ipv4Addr::LOCALHOST, port_base)),
562-
GeneralTimestampMode::SoftwareRecv,
563-
)
564-
.unwrap();
565-
let socket = socket
566-
.connect(SocketAddr::from((Ipv4Addr::LOCALHOST, port_base + 1)))
567-
.unwrap();
568-
569576
let test_socket = open_ip(
570-
SocketAddr::from((Ipv4Addr::LOCALHOST, port_base + 1)),
577+
SocketAddr::from((Ipv4Addr::LOCALHOST, port_base)),
571578
GeneralTimestampMode::SoftwareRecv,
572579
)
573580
.unwrap();
574-
let test_socket = test_socket
575-
.connect(SocketAddr::from((Ipv4Addr::LOCALHOST, port_base)))
576-
.unwrap();
577-
let source_addr = socket.peer_addr().unwrap();
578581

579582
let (_, system_snapshot_receiver) = tokio::sync::watch::channel(SystemSnapshot::default());
580583
let (_, synchronization_config_receiver) =
@@ -585,7 +588,7 @@ mod tests {
585588

586589
let local_clock_time = NtpInstant::now();
587590
let peer = Peer::new(
588-
source_addr,
591+
SocketAddr::from((Ipv4Addr::LOCALHOST, port_base)),
589592
local_clock_time,
590593
*peer_defaults_config_receiver.borrow_and_update(),
591594
ProtocolVersion::default(),
@@ -601,7 +604,11 @@ mod tests {
601604
synchronization_config_receiver,
602605
source_defaults_config_receiver: peer_defaults_config_receiver,
603606
},
604-
socket,
607+
source_addr: SocketAddr::from((Ipv4Addr::LOCALHOST, port_base)),
608+
interface: None,
609+
timestamp_mode: TimestampMode::KernelAll,
610+
socket: None,
611+
network_wait_period: std::time::Duration::from_secs(0),
605612
peer,
606613
last_send_timestamp: None,
607614
last_poll_sent: Instant::now(),
@@ -668,7 +675,7 @@ mod tests {
668675
let RecvResult {
669676
bytes_read: size,
670677
timestamp,
671-
..
678+
remote_addr,
672679
} = socket.recv(&mut buf).await.unwrap();
673680
assert_eq!(size, 48);
674681
let timestamp = timestamp.unwrap();
@@ -682,7 +689,7 @@ mod tests {
682689
);
683690

684691
let serialized = serialize_packet_unencryped(&send_packet);
685-
socket.send(&serialized).await.unwrap();
692+
socket.send_to(&serialized, remote_addr).await.unwrap();
686693

687694
let msg = msg_recv.recv().await.unwrap();
688695
assert!(matches!(msg, MsgForSystem::NewMeasurement(_, _, _)));
@@ -708,7 +715,7 @@ mod tests {
708715
let RecvResult {
709716
bytes_read: size,
710717
timestamp,
711-
..
718+
remote_addr,
712719
} = socket.recv(&mut buf).await.unwrap();
713720
assert_eq!(size, 48);
714721
assert!(timestamp.is_some());
@@ -717,7 +724,7 @@ mod tests {
717724
let send_packet = NtpPacket::deny_response(rec_packet);
718725
let serialized = serialize_packet_unencryped(&send_packet);
719726

720-
socket.send(&serialized).await.unwrap();
727+
socket.send_to(&serialized, remote_addr).await.unwrap();
721728

722729
let msg = msg_recv.recv().await.unwrap();
723730
assert!(matches!(msg, MsgForSystem::MustDemobilize(_)));

0 commit comments

Comments
 (0)