@@ -57,7 +57,11 @@ pub(crate) struct PeerTask<C: 'static + NtpClock + Send, T: Wait> {
57
57
_wait : PhantomData < T > ,
58
58
index : PeerId ,
59
59
clock : C ,
60
- socket : Socket < SocketAddr , Connected > ,
60
+ interface : Option < InterfaceName > ,
61
+ timestamp_mode : TimestampMode ,
62
+ source_addr : SocketAddr ,
63
+ network_wait_period : std:: time:: Duration ,
64
+ socket : Option < Socket < SocketAddr , Connected > > ,
61
65
channels : PeerChannels ,
62
66
63
67
peer : Peer ,
@@ -86,6 +90,12 @@ enum PacketResult {
86
90
Demobilize ,
87
91
}
88
92
93
+ #[ derive( Debug ) ]
94
+ enum SocketResult {
95
+ Ok ,
96
+ Abort ,
97
+ }
98
+
89
99
impl < C , T > PeerTask < C , T >
90
100
where
91
101
C : ' static + NtpClock + Send + Sync ,
@@ -160,7 +170,11 @@ where
160
170
}
161
171
}
162
172
163
- match self . socket . send ( packet) . await {
173
+ if matches ! ( self . setup_socket( ) . await , SocketResult :: Abort ) {
174
+ return PollResult :: NetworkGone ;
175
+ }
176
+
177
+ match self . socket . as_mut ( ) . unwrap ( ) . send ( packet) . await {
164
178
Err ( error) => {
165
179
warn ! ( ?error, "poll message could not be sent" ) ;
166
180
@@ -217,6 +231,8 @@ where
217
231
}
218
232
} ;
219
233
self . channels . msg_for_system_sender . send ( msg) . await . ok ( ) ;
234
+ // No longer needed since we don't expect any more packets
235
+ self . socket = None ;
220
236
}
221
237
Err ( IgnoreReason :: KissDemobilize ) => {
222
238
info ! ( "Demobilizing peer connection on request of remote." ) ;
@@ -233,6 +249,32 @@ where
233
249
PacketResult :: Ok
234
250
}
235
251
252
+ async fn setup_socket ( & mut self ) -> SocketResult {
253
+ let socket_res = match self . interface {
254
+ #[ cfg( target_os = "linux" ) ]
255
+ Some ( interface) => {
256
+ open_interface_udp (
257
+ interface,
258
+ 0 , /*lets os choose*/
259
+ self . timestamp_mode . as_interface_mode ( ) ,
260
+ )
261
+ . and_then ( |socket| socket. connect ( self . source_addr ) )
262
+ }
263
+ _ => connect_address ( self . source_addr , self . timestamp_mode . as_general_mode ( ) ) ,
264
+ } ;
265
+
266
+ self . socket = match socket_res {
267
+ Ok ( socket) => Some ( socket) ,
268
+ Err ( error) => {
269
+ warn ! ( ?error, "Could not open socket" ) ;
270
+ tokio:: time:: sleep ( self . network_wait_period ) . await ;
271
+ return SocketResult :: Abort ;
272
+ }
273
+ } ;
274
+
275
+ SocketResult :: Ok
276
+ }
277
+
236
278
async fn run ( & mut self , mut poll_wait : Pin < & mut T > ) {
237
279
loop {
238
280
let mut buf = [ 0_u8 ; 1024 ] ;
@@ -252,7 +294,7 @@ where
252
294
}
253
295
}
254
296
} ,
255
- result = self . socket. recv( & mut buf) => {
297
+ result = async { if let Some ( ref mut socket ) = self . socket { socket . recv( & mut buf) . await } else { std :: future :: pending ( ) . await } } => {
256
298
tracing:: debug!( "accept packet" ) ;
257
299
match accept_packet( result, & buf, & self . clock) {
258
300
AcceptResult :: Accept ( packet, recv_timestamp) => {
@@ -292,7 +334,7 @@ where
292
334
#[ instrument( skip( clock, channels) ) ]
293
335
pub fn spawn (
294
336
index : PeerId ,
295
- addr : SocketAddr ,
337
+ source_addr : SocketAddr ,
296
338
interface : Option < InterfaceName > ,
297
339
clock : C ,
298
340
timestamp_mode : TimestampMode ,
@@ -303,36 +345,6 @@ where
303
345
) -> tokio:: task:: JoinHandle < ( ) > {
304
346
tokio:: spawn (
305
347
( async move {
306
- let socket_res = match interface {
307
- #[ cfg( target_os = "linux" ) ]
308
- Some ( interface) => {
309
- open_interface_udp (
310
- interface,
311
- 0 , /*lets os choose*/
312
- timestamp_mode. as_interface_mode ( ) ,
313
- )
314
- . and_then ( |socket| socket. connect ( addr) )
315
- }
316
- _ => connect_address ( addr, timestamp_mode. as_general_mode ( ) ) ,
317
- } ;
318
-
319
- let socket = match socket_res {
320
- Ok ( socket) => socket,
321
- Err ( error) => {
322
- warn ! ( ?error, "Could not open socket" ) ;
323
- tokio:: time:: sleep ( network_wait_period) . await ;
324
- channels
325
- . msg_for_system_sender
326
- . send ( MsgForSystem :: NetworkIssue ( index) )
327
- . await
328
- . ok ( ) ;
329
- return ;
330
- }
331
- } ;
332
-
333
- // Unwrap should be safe because we know the socket was connected to a remote peer just before
334
- let source_addr = socket. peer_addr ( ) . unwrap ( ) ;
335
-
336
348
let local_clock_time = NtpInstant :: now ( ) ;
337
349
let config_snapshot = * channels. source_defaults_config_receiver . borrow_and_update ( ) ;
338
350
let peer = if let Some ( nts) = nts {
@@ -360,7 +372,11 @@ where
360
372
index,
361
373
clock,
362
374
channels,
363
- socket,
375
+ interface,
376
+ timestamp_mode,
377
+ network_wait_period,
378
+ source_addr,
379
+ socket : None ,
364
380
peer,
365
381
last_send_timestamp : None ,
366
382
last_poll_sent : Instant :: now ( ) ,
@@ -431,7 +447,7 @@ mod tests {
431
447
use std:: { io:: Cursor , net:: Ipv4Addr , sync:: Arc , time:: Duration } ;
432
448
433
449
use ntp_proto:: { NoCipher , NtpDuration , NtpLeapIndicator , NtpPacket , TimeSnapshot } ;
434
- use timestamped_socket:: socket:: { open_ip, GeneralTimestampMode } ;
450
+ use timestamped_socket:: socket:: { open_ip, GeneralTimestampMode , Open } ;
435
451
use tokio:: sync:: mpsc;
436
452
437
453
use crate :: daemon:: util:: EPOCH_OFFSET ;
@@ -552,29 +568,16 @@ mod tests {
552
568
port_base : u16 ,
553
569
) -> (
554
570
PeerTask < TestClock , T > ,
555
- Socket < SocketAddr , Connected > ,
571
+ Socket < SocketAddr , Open > ,
556
572
mpsc:: Receiver < MsgForSystem > ,
557
573
) {
558
574
// Note: Ports must be unique among tests to deal with parallelism, hence
559
575
// port_base
560
- let socket = open_ip (
561
- SocketAddr :: from ( ( Ipv4Addr :: LOCALHOST , port_base) ) ,
562
- GeneralTimestampMode :: SoftwareRecv ,
563
- )
564
- . unwrap ( ) ;
565
- let socket = socket
566
- . connect ( SocketAddr :: from ( ( Ipv4Addr :: LOCALHOST , port_base + 1 ) ) )
567
- . unwrap ( ) ;
568
-
569
576
let test_socket = open_ip (
570
- SocketAddr :: from ( ( Ipv4Addr :: LOCALHOST , port_base + 1 ) ) ,
577
+ SocketAddr :: from ( ( Ipv4Addr :: LOCALHOST , port_base) ) ,
571
578
GeneralTimestampMode :: SoftwareRecv ,
572
579
)
573
580
. unwrap ( ) ;
574
- let test_socket = test_socket
575
- . connect ( SocketAddr :: from ( ( Ipv4Addr :: LOCALHOST , port_base) ) )
576
- . unwrap ( ) ;
577
- let source_addr = socket. peer_addr ( ) . unwrap ( ) ;
578
581
579
582
let ( _, system_snapshot_receiver) = tokio:: sync:: watch:: channel ( SystemSnapshot :: default ( ) ) ;
580
583
let ( _, synchronization_config_receiver) =
@@ -585,7 +588,7 @@ mod tests {
585
588
586
589
let local_clock_time = NtpInstant :: now ( ) ;
587
590
let peer = Peer :: new (
588
- source_addr ,
591
+ SocketAddr :: from ( ( Ipv4Addr :: LOCALHOST , port_base ) ) ,
589
592
local_clock_time,
590
593
* peer_defaults_config_receiver. borrow_and_update ( ) ,
591
594
ProtocolVersion :: default ( ) ,
@@ -601,7 +604,11 @@ mod tests {
601
604
synchronization_config_receiver,
602
605
source_defaults_config_receiver : peer_defaults_config_receiver,
603
606
} ,
604
- socket,
607
+ source_addr : SocketAddr :: from ( ( Ipv4Addr :: LOCALHOST , port_base) ) ,
608
+ interface : None ,
609
+ timestamp_mode : TimestampMode :: KernelAll ,
610
+ socket : None ,
611
+ network_wait_period : std:: time:: Duration :: from_secs ( 0 ) ,
605
612
peer,
606
613
last_send_timestamp : None ,
607
614
last_poll_sent : Instant :: now ( ) ,
@@ -668,7 +675,7 @@ mod tests {
668
675
let RecvResult {
669
676
bytes_read : size,
670
677
timestamp,
671
- ..
678
+ remote_addr ,
672
679
} = socket. recv ( & mut buf) . await . unwrap ( ) ;
673
680
assert_eq ! ( size, 48 ) ;
674
681
let timestamp = timestamp. unwrap ( ) ;
@@ -682,7 +689,7 @@ mod tests {
682
689
) ;
683
690
684
691
let serialized = serialize_packet_unencryped ( & send_packet) ;
685
- socket. send ( & serialized) . await . unwrap ( ) ;
692
+ socket. send_to ( & serialized, remote_addr ) . await . unwrap ( ) ;
686
693
687
694
let msg = msg_recv. recv ( ) . await . unwrap ( ) ;
688
695
assert ! ( matches!( msg, MsgForSystem :: NewMeasurement ( _, _, _) ) ) ;
@@ -708,7 +715,7 @@ mod tests {
708
715
let RecvResult {
709
716
bytes_read : size,
710
717
timestamp,
711
- ..
718
+ remote_addr ,
712
719
} = socket. recv ( & mut buf) . await . unwrap ( ) ;
713
720
assert_eq ! ( size, 48 ) ;
714
721
assert ! ( timestamp. is_some( ) ) ;
@@ -717,7 +724,7 @@ mod tests {
717
724
let send_packet = NtpPacket :: deny_response ( rec_packet) ;
718
725
let serialized = serialize_packet_unencryped ( & send_packet) ;
719
726
720
- socket. send ( & serialized) . await . unwrap ( ) ;
727
+ socket. send_to ( & serialized, remote_addr ) . await . unwrap ( ) ;
721
728
722
729
let msg = msg_recv. recv ( ) . await . unwrap ( ) ;
723
730
assert ! ( matches!( msg, MsgForSystem :: MustDemobilize ( _) ) ) ;
0 commit comments