27
27
//! requests to a single node is also limited.
28
28
29
29
use std:: {
30
- collections:: { hash_map:: Entry , HashMap , HashSet } ,
30
+ collections:: {
31
+ hash_map:: { self , Entry } ,
32
+ HashMap , HashSet ,
33
+ } ,
31
34
fmt,
35
+ future:: Future ,
32
36
num:: NonZeroUsize ,
33
37
sync:: {
34
38
atomic:: { AtomicU64 , Ordering } ,
@@ -46,7 +50,7 @@ use tokio::{
46
50
sync:: { mpsc, oneshot} ,
47
51
task:: JoinSet ,
48
52
} ;
49
- use tokio_util:: { sync:: CancellationToken , time:: delay_queue} ;
53
+ use tokio_util:: { either :: Either , sync:: CancellationToken , time:: delay_queue} ;
50
54
use tracing:: { debug, error_span, trace, warn, Instrument } ;
51
55
52
56
use crate :: {
@@ -75,13 +79,15 @@ pub struct IntentId(pub u64);
75
79
/// Trait modeling a dialer. This allows for IO-less testing.
76
80
pub trait Dialer : Stream < Item = ( NodeId , anyhow:: Result < Self :: Connection > ) > + Unpin {
77
81
/// Type of connections returned by the Dialer.
78
- type Connection : Clone ;
82
+ type Connection : Clone + ' static ;
79
83
/// Dial a node.
80
84
fn queue_dial ( & mut self , node_id : NodeId ) ;
81
85
/// Get the number of dialing nodes.
82
86
fn pending_count ( & self ) -> usize ;
83
87
/// Check if a node is being dialed.
84
88
fn is_pending ( & self , node : NodeId ) -> bool ;
89
+ /// Get the node id of our node.
90
+ fn node_id ( & self ) -> NodeId ;
85
91
}
86
92
87
93
/// Signals what should be done with the request when it fails.
@@ -97,20 +103,39 @@ pub enum FailureAction {
97
103
RetryLater ( anyhow:: Error ) ,
98
104
}
99
105
100
- /// Future of a get request.
101
- type GetFut = BoxedLocal < InternalDownloadResult > ;
106
+ /// Future of a get request, for the checking stage.
107
+ type GetStartFut < N > = BoxedLocal < Result < GetOutput < N > , FailureAction > > ;
108
+ /// Future of a get request, for the downloading stage.
109
+ type GetProceedFut = BoxedLocal < InternalDownloadResult > ;
102
110
103
111
/// Trait modelling performing a single request over a connection. This allows for IO-less testing.
104
112
pub trait Getter {
105
113
/// Type of connections the Getter requires to perform a download.
106
- type Connection ;
107
- /// Return a future that performs the download using the given connection.
114
+ type Connection : ' static ;
115
+ /// Type of the intermediary state returned from [`Self::get`] if a connection is needed.
116
+ type NeedsConn : NeedsConn < Self :: Connection > ;
117
+ /// Returns a future that checks whether the request is already complete in the local store, returning
118
+ /// a struct implementing [`NeedsConn`] if we need a network connection to proceed.
108
119
fn get (
109
120
& mut self ,
110
121
kind : DownloadKind ,
111
- conn : Self :: Connection ,
112
122
progress_sender : BroadcastProgressSender ,
113
- ) -> GetFut ;
123
+ ) -> GetStartFut < Self :: NeedsConn > ;
124
+ }
125
+
126
+ /// Trait modelling the intermediary state when a connection is needed to proceed.
127
+ pub trait NeedsConn < C > : std:: fmt:: Debug + ' static {
128
+ /// Proceeds with the download using the given connection.
129
+ fn proceed ( self , conn : C ) -> GetProceedFut ;
130
+ }
131
+
132
+ /// Output returned from [`Getter::get`].
133
+ #[ derive( Debug ) ]
134
+ pub enum GetOutput < N > {
135
+ /// The request is already complete in the local store.
136
+ Complete ( Stats ) ,
137
+ /// The request needs a connection to continue.
138
+ NeedsConn ( N ) ,
114
139
}
115
140
116
141
/// Concurrency limits for the [`Downloader`].
@@ -280,7 +305,7 @@ pub struct DownloadHandle {
280
305
receiver : oneshot:: Receiver < ExternalDownloadResult > ,
281
306
}
282
307
283
- impl std :: future :: Future for DownloadHandle {
308
+ impl Future for DownloadHandle {
284
309
type Output = ExternalDownloadResult ;
285
310
286
311
fn poll (
@@ -424,10 +449,12 @@ struct IntentHandlers {
424
449
}
425
450
426
451
/// Information about a request.
427
- #[ derive( Debug , Default ) ]
428
- struct RequestInfo {
452
+ #[ derive( Debug ) ]
453
+ struct RequestInfo < NC > {
429
454
/// Registered intents with progress senders and result callbacks.
430
455
intents : HashMap < IntentId , IntentHandlers > ,
456
+ progress_sender : BroadcastProgressSender ,
457
+ get_state : Option < NC > ,
431
458
}
432
459
433
460
/// Information about a request in progress.
@@ -529,7 +556,7 @@ struct Service<G: Getter, D: Dialer> {
529
556
/// Queue of pending downloads.
530
557
queue : Queue ,
531
558
/// Information about pending and active requests.
532
- requests : HashMap < DownloadKind , RequestInfo > ,
559
+ requests : HashMap < DownloadKind , RequestInfo < G :: NeedsConn > > ,
533
560
/// State of running downloads.
534
561
active_requests : HashMap < DownloadKind , ActiveRequestInfo > ,
535
562
/// Tasks for currently running downloads.
@@ -666,48 +693,85 @@ impl<G: Getter<Connection = D::Connection>, D: Dialer> Service<G, D> {
666
693
on_progress : progress,
667
694
} ;
668
695
669
- // early exit if no providers.
670
- if nodes. is_empty ( ) && self . providers . get_candidates ( & kind. hash ( ) ) . next ( ) . is_none ( ) {
671
- self . finalize_download (
672
- kind,
673
- [ ( intent_id, intent_handlers) ] . into ( ) ,
674
- Err ( DownloadError :: NoProviders ) ,
675
- ) ;
676
- return ;
677
- }
678
-
679
696
// add the nodes to the provider map
680
- let updated = self
681
- . providers
682
- . add_hash_with_nodes ( kind. hash ( ) , nodes. iter ( ) . map ( |n| n. node_id ) ) ;
697
+ // (skip the node id of our own node - we should never attempt to download from ourselves)
698
+ let node_ids = nodes
699
+ . iter ( )
700
+ . map ( |n| n. node_id )
701
+ . filter ( |node_id| * node_id != self . dialer . node_id ( ) ) ;
702
+ let updated = self . providers . add_hash_with_nodes ( kind. hash ( ) , node_ids) ;
683
703
684
704
// queue the transfer (if not running) or attach to transfer progress (if already running)
685
- if self . active_requests . contains_key ( & kind) {
686
- // the transfer is already running, so attach the progress sender
687
- if let Some ( on_progress) = & intent_handlers. on_progress {
688
- // this is async because it sends the current state over the progress channel
689
- if let Err ( err) = self
690
- . progress_tracker
691
- . subscribe ( kind, on_progress. clone ( ) )
692
- . await
693
- {
694
- debug ! ( ?err, %kind, "failed to subscribe progress sender to transfer" ) ;
705
+ match self . requests . entry ( kind) {
706
+ hash_map:: Entry :: Occupied ( mut entry) => {
707
+ if let Some ( on_progress) = & intent_handlers. on_progress {
708
+ // this is async because it sends the current state over the progress channel
709
+ if let Err ( err) = self
710
+ . progress_tracker
711
+ . subscribe ( kind, on_progress. clone ( ) )
712
+ . await
713
+ {
714
+ debug ! ( ?err, %kind, "failed to subscribe progress sender to transfer" ) ;
715
+ }
695
716
}
717
+ entry. get_mut ( ) . intents . insert ( intent_id, intent_handlers) ;
696
718
}
697
- } else {
698
- // the transfer is not running.
699
- if updated && self . queue . is_parked ( & kind) {
700
- // the transfer is on hold for pending retries, and we added new nodes, so move back to queue.
701
- self . queue . unpark ( & kind) ;
702
- } else if !self . queue . contains ( & kind) {
703
- // the transfer is not yet queued: add to queue.
719
+ hash_map:: Entry :: Vacant ( entry) => {
720
+ tracing:: warn!( "is new, queue" ) ;
721
+ let progress_sender = self . progress_tracker . track (
722
+ kind,
723
+ intent_handlers
724
+ . on_progress
725
+ . clone ( )
726
+ . into_iter ( )
727
+ . collect :: < Vec < _ > > ( ) ,
728
+ ) ;
729
+
730
+ let get_state = match self . getter . get ( kind, progress_sender. clone ( ) ) . await {
731
+ Err ( _err) => {
732
+ self . finalize_download (
733
+ kind,
734
+ [ ( intent_id, intent_handlers) ] . into ( ) ,
735
+ // TODO: add better error variant? this is only triggered if the local
736
+ // store failed with local IO.
737
+ Err ( DownloadError :: DownloadFailed ) ,
738
+ ) ;
739
+ return ;
740
+ }
741
+ Ok ( GetOutput :: Complete ( stats) ) => {
742
+ self . finalize_download (
743
+ kind,
744
+ [ ( intent_id, intent_handlers) ] . into ( ) ,
745
+ Ok ( stats) ,
746
+ ) ;
747
+ return ;
748
+ }
749
+ Ok ( GetOutput :: NeedsConn ( state) ) => {
750
+ // early exit if no providers.
751
+ if self . providers . get_candidates ( & kind. hash ( ) ) . next ( ) . is_none ( ) {
752
+ self . finalize_download (
753
+ kind,
754
+ [ ( intent_id, intent_handlers) ] . into ( ) ,
755
+ Err ( DownloadError :: NoProviders ) ,
756
+ ) ;
757
+ return ;
758
+ }
759
+ state
760
+ }
761
+ } ;
762
+ entry. insert ( RequestInfo {
763
+ intents : [ ( intent_id, intent_handlers) ] . into_iter ( ) . collect ( ) ,
764
+ progress_sender,
765
+ get_state : Some ( get_state) ,
766
+ } ) ;
704
767
self . queue . insert ( kind) ;
705
768
}
706
769
}
707
770
708
- // store the request info
709
- let request_info = self . requests . entry ( kind) . or_default ( ) ;
710
- request_info. intents . insert ( intent_id, intent_handlers) ;
771
+ if updated && self . queue . is_parked ( & kind) {
772
+ // the transfer is on hold for pending retries, and we added new nodes, so move back to queue.
773
+ self . queue . unpark ( & kind) ;
774
+ }
711
775
}
712
776
713
777
/// Cancels a download intent.
@@ -860,7 +924,6 @@ impl<G: Getter<Connection = D::Connection>, D: Dialer> Service<G, D> {
860
924
) {
861
925
self . progress_tracker . remove ( & kind) ;
862
926
self . remove_hash_if_not_queued ( & kind. hash ( ) ) ;
863
- let result = result. map_err ( |_| DownloadError :: DownloadFailed ) ;
864
927
for ( _id, handlers) in intents. into_iter ( ) {
865
928
handlers. on_finish . send ( result. clone ( ) ) . ok ( ) ;
866
929
}
@@ -1082,14 +1145,9 @@ impl<G: Getter<Connection = D::Connection>, D: Dialer> Service<G, D> {
1082
1145
/// Panics if the request is not in `self.requests` or the node is not in `self.connected_nodes`.
1083
1146
fn start_download ( & mut self , kind : DownloadKind , node : NodeId ) {
1084
1147
let node_info = self . connected_nodes . get_mut ( & node) . expect ( "node exists" ) ;
1085
- let request_info = self . requests . get ( & kind) . expect ( "hash exists" ) ;
1086
-
1087
- // create a progress sender and subscribe all intents to the progress sender
1088
- let subscribers = request_info
1089
- . intents
1090
- . values ( )
1091
- . flat_map ( |state| state. on_progress . clone ( ) ) ;
1092
- let progress_sender = self . progress_tracker . track ( kind, subscribers) ;
1148
+ let request_info = self . requests . get_mut ( & kind) . expect ( "request exists" ) ;
1149
+ let progress = request_info. progress_sender . clone ( ) ;
1150
+ // .expect("queued state exists");
1093
1151
1094
1152
// create the active request state
1095
1153
let cancellation = CancellationToken :: new ( ) ;
@@ -1098,17 +1156,32 @@ impl<G: Getter<Connection = D::Connection>, D: Dialer> Service<G, D> {
1098
1156
node,
1099
1157
} ;
1100
1158
let conn = node_info. conn . clone ( ) ;
1101
- let get_fut = self . getter . get ( kind, conn, progress_sender) ;
1159
+
1160
+ // If this is the first provider node we try, we have an initial state
1161
+ // from starting the generator in Self::handle_queue_new_download.
1162
+ // If this is not the first provider node we try, we have to recreate the generator, because
1163
+ // we can only resume it once.
1164
+ let get_state = match request_info. get_state . take ( ) {
1165
+ Some ( state) => Either :: Left ( async move { Ok ( GetOutput :: NeedsConn ( state) ) } ) ,
1166
+ None => Either :: Right ( self . getter . get ( kind, progress) ) ,
1167
+ } ;
1102
1168
let fut = async move {
1103
1169
// NOTE: it's an open question if we should do timeouts at this point. Considerations from @Frando:
1104
1170
// > at this stage we do not know the size of the download, so the timeout would have
1105
1171
// > to be so large that it won't be useful for non-huge downloads. At the same time,
1106
1172
// > this means that a super slow node would block a download from succeeding for a long
1107
1173
// > time, while faster nodes could be readily available.
1108
1174
// As a conclusion, timeouts should be added only after downloads are known to be bounded
1175
+ let fut = async move {
1176
+ match get_state. await ? {
1177
+ GetOutput :: Complete ( stats) => Ok ( stats) ,
1178
+ GetOutput :: NeedsConn ( state) => state. proceed ( conn) . await ,
1179
+ }
1180
+ } ;
1181
+ tokio:: pin!( fut) ;
1109
1182
let res = tokio:: select! {
1110
1183
_ = cancellation. cancelled( ) => Err ( FailureAction :: AllIntentsDropped ) ,
1111
- res = get_fut => res
1184
+ res = & mut fut => res
1112
1185
} ;
1113
1186
trace ! ( "transfer finished" ) ;
1114
1187
@@ -1433,4 +1506,8 @@ impl Dialer for iroh_net::dialer::Dialer {
1433
1506
fn is_pending ( & self , node : NodeId ) -> bool {
1434
1507
self . is_pending ( node)
1435
1508
}
1509
+
1510
+ fn node_id ( & self ) -> NodeId {
1511
+ self . endpoint ( ) . node_id ( )
1512
+ }
1436
1513
}
0 commit comments