1
1
// TODO: rustls TLS impl for preventing MITM attacks
2
2
3
- use std:: net:: SocketAddr ;
3
+ use std:: { net:: SocketAddr , sync :: Arc } ;
4
4
5
5
use bitcode:: { Decode , Encode } ;
6
6
use qb_core:: common:: QBDeviceId ;
@@ -9,8 +9,10 @@ use qb_ext::{
9
9
interface:: { QBIChannel , QBIContext , QBIHostMessage , QBIMessage , QBISetup , QBISlaveMessage } ,
10
10
} ;
11
11
use qb_proto:: QBP ;
12
+ use rustls:: { pki_types:: ServerName , ClientConfig , RootCertStore , ServerConfig } ;
12
13
use serde:: { Deserialize , Serialize } ;
13
14
use tokio:: net:: { TcpListener , TcpSocket , TcpStream } ;
15
+ use tokio_rustls:: { TlsAcceptor , TlsConnector , TlsStream } ;
14
16
use tracing:: { error, info, warn} ;
15
17
16
18
/// A hook which listens for incoming connections and yields
@@ -49,12 +51,20 @@ impl QBHServerSocket {
49
51
50
52
impl QBHContext < QBIServerSocket > for QBHServerSocket {
51
53
async fn run ( self , init : QBHInit < QBIServerSocket > ) {
54
+ let root_cert_store = RootCertStore :: empty ( ) ;
55
+ // TODO: add root certificate
56
+ let config = ServerConfig :: builder ( )
57
+ . with_no_client_auth ( )
58
+ . with_single_cert ( todo ! ( ) , todo ! ( ) )
59
+ . unwrap ( ) ;
60
+
52
61
loop {
53
62
// listen on incoming connections
54
63
let ( stream, addr) = self . listener . accept ( ) . await . unwrap ( ) ;
55
64
info ! ( "connected: {}" , addr) ;
56
65
// yield a [QBIServerSocket]
57
66
init. attach ( QBIServerSocket {
67
+ config,
58
68
stream,
59
69
auth : self . auth . clone ( ) ,
60
70
} )
@@ -77,7 +87,17 @@ impl QBIContext for QBIClientSocket {
77
87
78
88
let socket = TcpSocket :: new_v4 ( ) . unwrap ( ) ;
79
89
let addr = self . addr . parse ( ) . unwrap ( ) ;
80
- let mut stream = socket. connect ( addr) . await . unwrap ( ) ;
90
+ let stream = socket. connect ( addr) . await . unwrap ( ) ;
91
+
92
+ let root_cert_store = RootCertStore :: empty ( ) ;
93
+ // TODO: add root certificate
94
+ let config = ClientConfig :: builder ( )
95
+ . with_root_certificates ( root_cert_store)
96
+ . with_no_client_auth ( ) ;
97
+ let connector = TlsConnector :: from ( Arc :: new ( config) ) ;
98
+ let dnsname = ServerName :: try_from ( "quixbyte.application" ) . unwrap ( ) ;
99
+ let mut stream = connector. connect ( dnsname, stream) . await . unwrap ( ) ;
100
+
81
101
let mut protocol = QBP :: default ( ) ;
82
102
protocol. negotiate ( & mut stream) . await . unwrap ( ) ;
83
103
protocol
@@ -90,7 +110,7 @@ impl QBIContext for QBIClientSocket {
90
110
let runner = Runner {
91
111
host_id,
92
112
com,
93
- stream,
113
+ stream : TlsStream :: Client ( stream ) ,
94
114
protocol,
95
115
} ;
96
116
@@ -107,13 +127,18 @@ impl<'a> QBISetup<'a> for QBIClientSocket {
107
127
#[ derive( Debug ) ]
108
128
pub struct QBIServerSocket {
109
129
pub stream : TcpStream ,
130
+ pub config : ServerConfig ,
110
131
/// An authentication token sent on boot
111
132
pub auth : Vec < u8 > ,
112
133
}
113
134
114
135
impl QBIContext for QBIServerSocket {
115
136
async fn run ( self , host_id : QBDeviceId , com : QBIChannel ) {
116
- let mut stream = self . stream ;
137
+ let stream = self . stream ;
138
+
139
+ let acceptor = TlsAcceptor :: from ( Arc :: new ( self . config ) ) ;
140
+ let mut stream = acceptor. accept ( stream) . await . unwrap ( ) ;
141
+
117
142
let mut protocol = QBP :: default ( ) ;
118
143
protocol. negotiate ( & mut stream) . await . unwrap ( ) ;
119
144
let auth = protocol. recv_payload ( & mut stream) . await . unwrap ( ) ;
@@ -125,7 +150,7 @@ impl QBIContext for QBIServerSocket {
125
150
let runner = Runner {
126
151
host_id,
127
152
com,
128
- stream,
153
+ stream : TlsStream :: Server ( stream ) ,
129
154
protocol,
130
155
} ;
131
156
@@ -136,7 +161,7 @@ impl QBIContext for QBIServerSocket {
136
161
struct Runner {
137
162
host_id : QBDeviceId ,
138
163
com : QBIChannel ,
139
- stream : TcpStream ,
164
+ stream : TlsStream < TcpStream > ,
140
165
protocol : QBP ,
141
166
}
142
167
0 commit comments