1
+ //! # MQTT Token Fetcher
2
+ //!
3
+ //! `MqttTokenFetcher` is responsible for fetching and managing MQTT tokens for DSH.
1
4
use std:: {
2
5
fmt:: { Display , Formatter } ,
3
6
sync:: Mutex ,
@@ -21,8 +24,9 @@ pub struct MqttTokenFetcher {
21
24
tenant_name : String ,
22
25
rest_api_key : String ,
23
26
rest_token : Mutex < RestToken > ,
27
+ rest_auth_url : String ,
24
28
mqtt_token : DashMap < String , MqttToken > , // Mapping from Client ID to MqttToken
25
- platform : Platform ,
29
+ mqtt_auth_url : String ,
26
30
//token_lifetime: Option<i32>, // TODO: Implement option of passing token lifetime to request token for specific duration
27
31
// port: Port or connection_type: Connection // TODO: Platform provides two connection options, current implemetation only provides connecting over SSL, enable WebSocket too
28
32
}
@@ -45,8 +49,9 @@ impl MqttTokenFetcher {
45
49
tenant_name,
46
50
rest_api_key,
47
51
rest_token : Mutex :: new ( rest_token) ,
52
+ rest_auth_url : platform. endpoint_rest_token ( ) . to_string ( ) ,
48
53
mqtt_token : DashMap :: new ( ) ,
49
- platform,
54
+ mqtt_auth_url : platform. endpoint_mqtt_token ( ) . to_string ( ) ,
50
55
}
51
56
}
52
57
/// Retrieves an MQTT token for the specified client ID.
@@ -66,18 +71,20 @@ impl MqttTokenFetcher {
66
71
client_id : & str ,
67
72
claims : Option < Vec < Claims > > ,
68
73
) -> Result < MqttToken , DshError > {
69
- let mut mqtt_token = self
70
- . mqtt_token
71
- . entry ( client_id. to_string ( ) )
72
- . or_insert ( self . fetch_new_mqtt_token ( client_id, claims. clone ( ) ) . await ?) ;
73
-
74
- if !mqtt_token. is_valid ( ) {
75
- * mqtt_token = self
76
- . fetch_new_mqtt_token ( client_id, claims. clone ( ) )
77
- . await
78
- . unwrap ( )
79
- } ;
80
- Ok ( mqtt_token. clone ( ) )
74
+ match self . mqtt_token . entry ( client_id. to_string ( ) ) {
75
+ dashmap:: Entry :: Occupied ( mut entry) => {
76
+ let mqtt_token = entry. get_mut ( ) ;
77
+ if !mqtt_token. is_valid ( ) {
78
+ * mqtt_token = self . fetch_new_mqtt_token ( client_id, claims) . await ?;
79
+ } ;
80
+ Ok ( mqtt_token. clone ( ) )
81
+ }
82
+ dashmap:: Entry :: Vacant ( entry) => {
83
+ let mqtt_token = self . fetch_new_mqtt_token ( client_id, claims) . await ?;
84
+ entry. insert ( mqtt_token. clone ( ) ) ;
85
+ Ok ( mqtt_token)
86
+ }
87
+ }
81
88
}
82
89
/// Fetches a new MQTT token from the platform.
83
90
///
@@ -94,7 +101,7 @@ impl MqttTokenFetcher {
94
101
95
102
if !rest_token. is_valid ( ) {
96
103
* rest_token =
97
- RestToken :: get ( & self . tenant_name , & self . rest_api_key , & self . platform ) . await ?
104
+ RestToken :: get ( & self . tenant_name , & self . rest_api_key , & self . rest_auth_url ) . await ?
98
105
}
99
106
100
107
let authorization_header = format ! ( "Bearer {}" , rest_token. raw_token) ;
@@ -103,7 +110,7 @@ impl MqttTokenFetcher {
103
110
let payload = serde_json:: to_value ( & mqtt_token_request) ?;
104
111
105
112
let response = mqtt_token_request
106
- . send ( & self . platform , & authorization_header, & payload)
113
+ . send ( & self . mqtt_auth_url , & authorization_header, & payload)
107
114
. await ?;
108
115
109
116
MqttToken :: new ( response)
@@ -202,7 +209,7 @@ impl MqttTokenRequest {
202
209
203
210
async fn send (
204
211
& self ,
205
- platform : & Platform ,
212
+ mqtt_auth_url : & str ,
206
213
authorization_header : & str ,
207
214
payload : & serde_json:: Value ,
208
215
) -> Result < String , DshError > {
@@ -215,7 +222,7 @@ impl MqttTokenRequest {
215
222
. expect ( "Failed to build reqwest client" ) ;
216
223
217
224
let response = reqwest_client
218
- . post ( platform . endpoint_mqtt_token ( ) )
225
+ . post ( mqtt_auth_url )
219
226
. header ( "Authorization" , authorization_header)
220
227
. json ( payload)
221
228
. send ( )
@@ -225,7 +232,7 @@ impl MqttTokenRequest {
225
232
Ok ( response. text ( ) . await ?)
226
233
} else {
227
234
Err ( DshError :: DshCallError {
228
- url : platform . endpoint_mqtt_token ( ) . to_string ( ) ,
235
+ url : mqtt_auth_url . to_string ( ) ,
229
236
status_code : response. status ( ) ,
230
237
error_body : response. text ( ) . await ?,
231
238
} )
@@ -283,7 +290,7 @@ impl MqttToken {
283
290
. duration_since ( UNIX_EPOCH )
284
291
. expect ( "SystemTime before UNIX EPOCH!" )
285
292
. as_secs ( ) as i32 ;
286
- self . exp >= current_unixtime - 5
293
+ self . exp >= current_unixtime + 5
287
294
}
288
295
}
289
296
@@ -292,10 +299,10 @@ impl MqttToken {
292
299
#[ serde( rename_all = "kebab-case" ) ]
293
300
struct RestTokenAttributes {
294
301
gen : i64 ,
295
- pub endpoint : String ,
302
+ endpoint : String ,
296
303
iss : String ,
297
- pub claims : RestClaims ,
298
- pub exp : i64 ,
304
+ claims : RestClaims ,
305
+ exp : i32 ,
299
306
tenant_id : String ,
300
307
}
301
308
@@ -312,7 +319,7 @@ struct DatastreamsData {}
312
319
#[ derive( Serialize , Deserialize , Debug ) ]
313
320
struct RestToken {
314
321
raw_token : String ,
315
- exp : i64 ,
322
+ exp : i32 ,
316
323
}
317
324
318
325
impl RestToken {
@@ -327,8 +334,8 @@ impl RestToken {
327
334
/// # Returns
328
335
///
329
336
/// A Result containing the created `RestToken` or a `DshError`.
330
- async fn get ( tenant : & str , api_key : & str , env : & Platform ) -> Result < RestToken , DshError > {
331
- let raw_token = Self :: fetch_token ( tenant, api_key, env ) . await . unwrap ( ) ;
337
+ async fn get ( tenant : & str , api_key : & str , auth_url : & str ) -> Result < RestToken , DshError > {
338
+ let raw_token = Self :: fetch_token ( tenant, api_key, auth_url ) . await . unwrap ( ) ;
332
339
333
340
let header_payload = extract_header_and_payload ( & raw_token) ?;
334
341
@@ -347,11 +354,11 @@ impl RestToken {
347
354
let current_unixtime = SystemTime :: now ( )
348
355
. duration_since ( UNIX_EPOCH )
349
356
. expect ( "SystemTime before UNIX EPOCH!" )
350
- . as_secs ( ) as i64 ;
351
- self . exp >= current_unixtime - 5
357
+ . as_secs ( ) as i32 ;
358
+ self . exp >= current_unixtime + 5
352
359
}
353
360
354
- async fn fetch_token ( tenant : & str , api_key : & str , env : & Platform ) -> Result < String , DshError > {
361
+ async fn fetch_token ( tenant : & str , api_key : & str , auth_url : & str ) -> Result < String , DshError > {
355
362
let json_body = json ! ( { "tenant" : tenant} ) ;
356
363
357
364
const DEFAULT_TIMEOUT : Duration = Duration :: from_secs ( 10 ) ;
@@ -364,7 +371,7 @@ impl RestToken {
364
371
365
372
let rest_client = reqwest_client;
366
373
let response = rest_client
367
- . post ( env . endpoint_rest_token ( ) )
374
+ . post ( auth_url )
368
375
. header ( "apikey" , api_key)
369
376
. json ( & json_body)
370
377
. send ( )
@@ -375,7 +382,7 @@ impl RestToken {
375
382
match status {
376
383
reqwest:: StatusCode :: OK => Ok ( body_text) ,
377
384
_ => Err ( DshError :: DshCallError {
378
- url : env . endpoint_rest_api ( ) . to_string ( ) ,
385
+ url : auth_url . to_string ( ) ,
379
386
status_code : status,
380
387
error_body : body_text,
381
388
} ) ,
@@ -437,6 +444,33 @@ fn decode_base64(payload: &str) -> Result<Vec<u8>, DshError> {
437
444
mod tests {
438
445
use super :: * ;
439
446
447
+ fn create_valid_fetcher ( ) -> MqttTokenFetcher {
448
+ let exp_time = SystemTime :: now ( )
449
+ . duration_since ( UNIX_EPOCH )
450
+ . unwrap ( )
451
+ . as_secs ( ) as i32
452
+ + 3600 ;
453
+ println ! ( "exp_time: {}" , exp_time) ;
454
+ let rest_token: RestToken = RestToken {
455
+ exp : exp_time as i32 ,
456
+ raw_token : "valid.token.payload" . to_string ( ) ,
457
+ } ;
458
+ let mqtt_token = MqttToken {
459
+ exp : exp_time,
460
+ raw_token : "valid.token.payload" . to_string ( ) ,
461
+ } ;
462
+ let mqtt_token_map = DashMap :: new ( ) ;
463
+ mqtt_token_map. insert ( "test_client" . to_string ( ) , mqtt_token. clone ( ) ) ;
464
+ MqttTokenFetcher {
465
+ tenant_name : "test_tenant" . to_string ( ) ,
466
+ rest_api_key : "test_api_key" . to_string ( ) ,
467
+ rest_token : Mutex :: new ( rest_token) ,
468
+ rest_auth_url : "test_auth_url" . to_string ( ) ,
469
+ mqtt_token : mqtt_token_map,
470
+ mqtt_auth_url : "test_auth_url" . to_string ( ) ,
471
+ }
472
+ }
473
+
440
474
#[ tokio:: test]
441
475
async fn test_mqtt_token_fetcher_new ( ) {
442
476
let tenant_name = "test_tenant" . to_string ( ) ;
@@ -448,6 +482,13 @@ mod tests {
448
482
assert ! ( fetcher. mqtt_token. is_empty( ) ) ;
449
483
}
450
484
485
+ #[ tokio:: test]
486
+ async fn test_mqtt_token_fetcher_get_token ( ) {
487
+ let fetcher = create_valid_fetcher ( ) ;
488
+ let token = fetcher. get_token ( "test_client" , None ) . await . unwrap ( ) ;
489
+ assert_eq ! ( token. raw_token, "valid.token.payload" ) ;
490
+ }
491
+
451
492
#[ test]
452
493
fn test_claims_new ( ) {
453
494
let resource = Resource :: new (
@@ -492,18 +533,66 @@ mod tests {
492
533
493
534
assert ! ( token. is_valid( ) ) ;
494
535
}
536
+ #[ test]
537
+ fn test_mqtt_token_is_invalid ( ) {
538
+ let raw_token = "valid.token.payload" . to_string ( ) ;
539
+ let token = MqttToken {
540
+ exp : SystemTime :: now ( )
541
+ . duration_since ( UNIX_EPOCH )
542
+ . unwrap ( )
543
+ . as_secs ( ) as i32 ,
544
+ raw_token,
545
+ } ;
546
+
547
+ assert ! ( !token. is_valid( ) ) ;
548
+ }
495
549
496
550
#[ test]
497
551
fn test_rest_token_is_valid ( ) {
498
552
let token = RestToken {
499
553
exp : SystemTime :: now ( )
500
554
. duration_since ( UNIX_EPOCH )
501
555
. unwrap ( )
502
- . as_secs ( ) as i64
556
+ . as_secs ( ) as i32
503
557
+ 3600 ,
504
558
raw_token : "valid.token.payload" . to_string ( ) ,
505
559
} ;
506
560
507
561
assert ! ( token. is_valid( ) ) ;
508
562
}
563
+
564
+ #[ test]
565
+ fn test_rest_token_is_invalid ( ) {
566
+ let token = RestToken {
567
+ exp : SystemTime :: now ( )
568
+ . duration_since ( UNIX_EPOCH )
569
+ . unwrap ( )
570
+ . as_secs ( ) as i32 ,
571
+ raw_token : "valid.token.payload" . to_string ( ) ,
572
+ } ;
573
+
574
+ assert ! ( !token. is_valid( ) ) ;
575
+ }
576
+
577
+ #[ test]
578
+ fn test_rest_token_default_is_invalid ( ) {
579
+ let token = RestToken :: default ( ) ;
580
+
581
+ assert ! ( !token. is_valid( ) ) ;
582
+ }
583
+
584
+ #[ test]
585
+ fn test_extract_header_and_payload ( ) {
586
+ let raw = "header.payload.signature" ;
587
+ let result = extract_header_and_payload ( raw) . unwrap ( ) ;
588
+ assert_eq ! ( result, "payload" ) ;
589
+
590
+ let raw = "header.payload" ;
591
+ let result = extract_header_and_payload ( raw) . unwrap ( ) ;
592
+ assert_eq ! ( result, "payload" ) ;
593
+
594
+ let raw = "header" ;
595
+ let result = extract_header_and_payload ( raw) ;
596
+ assert ! ( result. is_err( ) ) ;
597
+ }
509
598
}
0 commit comments