@@ -25,61 +25,65 @@ import (
25
25
"time"
26
26
)
27
27
28
+ const (
29
+ idType = 0 // address type index
30
+ idIP0 = 1 // ip addres start index
31
+ idDmLen = 1 // domain address length index
32
+ idDm0 = 2 // domain address start index
33
+
34
+ typeIPv4 = 1 // type is ipv4 address
35
+ typeDm = 3 // type is domain address
36
+ typeIPv6 = 4 // type is ipv6 address
37
+
38
+ lenIPv4 = net .IPv4len + 2 // ipv4 + 2port
39
+ lenIPv6 = net .IPv6len + 2 // ipv6 + 2port
40
+ lenDmBase = 2 // 1addrLen + 2port, plus addrLen
41
+ lenHmacSha1 = 10
42
+ )
43
+
28
44
var ssdebug ss.DebugLog
29
45
30
- func getRequest (conn * ss.Conn ) (host string , extra []byte , err error ) {
31
- const (
32
- idType = 0 // address type index
33
- idIP0 = 1 // ip addres start index
34
- idDmLen = 1 // domain address length index
35
- idDm0 = 2 // domain address start index
36
-
37
- typeIPv4 = 1 // type is ipv4 address
38
- typeDm = 3 // type is domain address
39
- typeIPv6 = 4 // type is ipv6 address
40
-
41
- lenIPv4 = 1 + net .IPv4len + 2 // 1addrType + ipv4 + 2port
42
- lenIPv6 = 1 + net .IPv6len + 2 // 1addrType + ipv6 + 2port
43
- lenDmBase = 1 + 1 + 2 // 1addrType + 1addrLen + 2port, plus addrLen
44
- )
46
+ func getRequest (conn * ss.Conn , auth bool ) (host string , res_size int , ota bool , err error ) {
47
+ var n int
48
+ ss .SetReadTimeout (conn )
45
49
46
50
// buf size should at least have the same size with the largest possible
47
51
// request size (when addrType is 3, domain name has at most 256 bytes)
48
- // 1(addrType) + 1(lenByte) + 256(max length address) + 2(port)
49
- buf := make ([]byte , 260 )
50
- var n int
52
+ // 1(addrType) + 1(lenByte) + 256(max length address) + 2(port) + 10(hmac-sha1)
53
+ buf := make ([]byte , 270 )
51
54
// read till we get possible domain length field
52
- ss .SetReadTimeout (conn )
53
- if n , err = io .ReadAtLeast (conn , buf , idDmLen + 1 ); err != nil {
55
+ if n , err = io .ReadFull (conn , buf [:idType + 1 ]); err != nil {
54
56
return
55
57
}
58
+ res_size += n
56
59
57
- reqLen := - 1
58
- switch buf [idType ] {
60
+ var reqStart , reqEnd int
61
+ addrType := buf [idType ]
62
+ switch addrType & ss .AddrMask {
59
63
case typeIPv4 :
60
- reqLen = lenIPv4
64
+ reqStart , reqEnd = idIP0 , idIP0 + lenIPv4
61
65
case typeIPv6 :
62
- reqLen = lenIPv6
66
+ reqStart , reqEnd = idIP0 , idIP0 + lenIPv6
63
67
case typeDm :
64
- reqLen = int (buf [idDmLen ]) + lenDmBase
68
+ if n , err = io .ReadFull (conn , buf [idType + 1 :idDmLen + 1 ]); err != nil {
69
+ return
70
+ }
71
+ reqStart , reqEnd = idDm0 , int (idDm0 + buf [idDmLen ]+ lenDmBase )
65
72
default :
66
- err = fmt .Errorf ("addr type %d not supported" , buf [ idType ] )
73
+ err = fmt .Errorf ("addr type %d not supported" , addrType & ss . AddrMask )
67
74
return
68
75
}
76
+ res_size += n
69
77
70
- if n < reqLen { // rare case
71
- if _ , err = io .ReadFull (conn , buf [n :reqLen ]); err != nil {
72
- return
73
- }
74
- } else if n > reqLen {
75
- // it's possible to read more than just the request head
76
- extra = buf [reqLen :n ]
78
+ if n , err = io .ReadFull (conn , buf [reqStart :reqEnd ]); err != nil {
79
+ return
77
80
}
81
+ res_size += n
78
82
79
83
// Return string for typeIP is not most efficient, but browsers (Chrome,
80
84
// Safari, Firefox) all seems using typeDm exclusively. So this is not a
81
85
// big problem.
82
- switch buf [ idType ] {
86
+ switch addrType & ss . AddrMask {
83
87
case typeIPv4 :
84
88
host = net .IP (buf [idIP0 : idIP0 + net .IPv4len ]).String ()
85
89
case typeIPv6 :
@@ -88,8 +92,23 @@ func getRequest(conn *ss.Conn) (host string, extra []byte, err error) {
88
92
host = string (buf [idDm0 : idDm0 + buf [idDmLen ]])
89
93
}
90
94
// parse port
91
- port := binary .BigEndian .Uint16 (buf [reqLen - 2 : reqLen ])
95
+ port := binary .BigEndian .Uint16 (buf [reqEnd - 2 : reqEnd ])
92
96
host = net .JoinHostPort (host , strconv .Itoa (int (port )))
97
+ // if specified one time auth enabled, we should verify this
98
+ if auth || addrType & ss .OneTimeAuthMask > 0 {
99
+ ota = true
100
+ if n , err = io .ReadFull (conn , buf [reqEnd :reqEnd + lenHmacSha1 ]); err != nil {
101
+ return
102
+ }
103
+ iv := conn .GetIv ()
104
+ key := conn .GetKey ()
105
+ actualHmacSha1Buf := ss .HmacSha1 (append (iv , key ... ), buf [:reqEnd ])
106
+ if ! bytes .Equal (buf [reqEnd :reqEnd + lenHmacSha1 ], actualHmacSha1Buf ) {
107
+ err = fmt .Errorf ("verify one time auth failed, iv=%v key=%v data=%v" , iv , key , buf [:reqEnd ])
108
+ return
109
+ }
110
+ res_size += n
111
+ }
93
112
return
94
113
}
95
114
@@ -98,13 +117,9 @@ const logCntDelta = 100
98
117
var connCnt int
99
118
var nextLogConnCnt int = logCntDelta
100
119
101
- func handleConnection (user user.User , conn * ss.Conn ) {
120
+ func handleConnection (user user.User , conn * ss.Conn , auth bool ) {
102
121
var host string
103
- var size = 0
104
- var raw_req_header , raw_res_header []byte
105
- var is_http = false
106
- var res_size = 0
107
- var req_chan = make (chan []byte )
122
+
108
123
connCnt ++ // this maybe not accurate, but should be enough
109
124
if connCnt - nextLogConnCnt >= 0 {
110
125
// XXX There's no xadd in the atomic package, so it's difficult to log
@@ -128,7 +143,7 @@ func handleConnection(user user.User, conn *ss.Conn) {
128
143
}
129
144
}()
130
145
131
- host , extra , err := getRequest (conn )
146
+ host , res_size , ota , err := getRequest (conn , auth )
132
147
if err != nil {
133
148
Log .Error ("error getting request" , conn .RemoteAddr (), conn .LocalAddr (), err )
134
149
return
@@ -151,65 +166,30 @@ func handleConnection(user user.User, conn *ss.Conn) {
151
166
}
152
167
}()
153
168
154
- defer func () {
155
- if is_http {
156
- tmp_req_header := <- req_chan
157
- buffer := bytes .NewBuffer (raw_req_header )
158
- buffer .Write (tmp_req_header )
159
- raw_req_header = buffer .Bytes ()
160
- }
161
- showConn (raw_req_header , raw_res_header , host , user , size , is_http )
162
- close (req_chan )
163
- if ! closed {
164
- remote .Close ()
165
- }
166
- }()
169
+ // debug conn info
170
+ Log .Debug (fmt .Sprintf ("%d conn debug: local addr: %s | remote addr: %s network: %s " , user .GetPort (),
171
+ conn .LocalAddr ().String (), conn .RemoteAddr ().String (), conn .RemoteAddr ().Network ()))
172
+ err = storage .IncrSize (user , res_size )
173
+ if err != nil {
174
+ Log .Error (err )
175
+ return
176
+ }
177
+ err = storage .MarkUserOnline (user )
178
+ if err != nil {
179
+ Log .Error (err )
180
+ return
181
+ }
182
+ Log .Debug (fmt .Sprintf ("[port-%d] store size: %d" , user .GetPort (), res_size ))
167
183
168
- // write extra bytes read from
169
- if extra != nil {
170
- // debug.Println("getRequest read extra data, writing to remote, len", len(extra))
171
- is_http , extra , _ = checkHttp (extra , conn )
172
- if strings .HasSuffix (host , ":80" ) {
173
- is_http = true
174
- }
175
- raw_req_header = extra
176
- res_size , err = remote .Write (extra )
177
- // size, err := remote.Write(extra)
178
- if err != nil {
179
- Log .Error ("write request extra error:" , err )
180
- return
181
- }
182
- // debug conn info
183
- Log .Debug (fmt .Sprintf ("%d conn debug: local addr: %s | remote addr: %s network: %s " , user .GetPort (),
184
- conn .LocalAddr ().String (), conn .RemoteAddr ().String (), conn .RemoteAddr ().Network ()))
185
- err = storage .IncrSize (user , res_size )
186
- if err != nil {
187
- Log .Error (err )
188
- return
189
- }
190
- err = storage .MarkUserOnline (user )
191
- if err != nil {
192
- Log .Error (err )
193
- return
194
- }
195
- Log .Debug (fmt .Sprintf ("[port-%d] store size: %d" , user .GetPort (), res_size ))
184
+ Log .Info (fmt .Sprintf ("piping %s<->%s ota=%v connOta=%v" , conn .RemoteAddr (), host , ota , conn .IsOta ()))
185
+
186
+ if ota {
187
+ go PipeThenCloseOta (conn , remote , false , host , user )
188
+ } else {
189
+ go PipeThenClose (conn , remote , false , host , user )
196
190
}
197
- Log .Debug (fmt .Sprintf ("piping %s<->%s" , conn .RemoteAddr (), host ))
198
- /**
199
- go ss.PipeThenClose(conn, remote)
200
- ss.PipeThenClose(remote, conn)
201
- closed = true
202
- return
203
- **/
204
- go func () {
205
- _ , raw_header := PipeThenClose (conn , remote , is_http , false , host , user )
206
- if is_http {
207
- req_chan <- raw_header
208
- }
209
- }()
210
191
211
- res_size , raw_res_header = PipeThenClose (remote , conn , is_http , true , host , user )
212
- size += res_size
192
+ PipeThenClose (remote , conn , true , host , user )
213
193
closed = true
214
194
return
215
195
}
@@ -273,7 +253,7 @@ func runWithCustomMethod(user user.User) {
273
253
os .Exit (1 )
274
254
}
275
255
passwdManager .add (port , password , ln )
276
- cipher , err := user .GetCipher ()
256
+ cipher , err , auth := user .GetCipher ()
277
257
if err != nil {
278
258
return
279
259
}
@@ -288,27 +268,34 @@ func runWithCustomMethod(user user.User) {
288
268
// Creating cipher upon first connection.
289
269
if cipher == nil {
290
270
Log .Debug ("creating cipher for port:" , port )
291
- cipher , err = ss .NewCipher (user .GetMethod (), password )
271
+ method := user .GetMethod ()
272
+
273
+ if strings .HasSuffix (method , "-auth" ) {
274
+ method = method [:len (method )- 5 ]
275
+ auth = true
276
+ } else {
277
+ auth = false
278
+ }
279
+
280
+ cipher , err = ss .NewCipher (method , password )
292
281
if err != nil {
293
282
Log .Error (fmt .Sprintf ("Error generating cipher for port: %s %v\n " , port , err ))
294
283
conn .Close ()
295
284
continue
296
285
}
297
286
}
298
- go handleConnection (user , ss .NewConn (conn , cipher .Copy ()))
287
+ go handleConnection (user , ss .NewConn (conn , cipher .Copy ()), auth )
299
288
}
300
289
}
301
290
302
291
const bufSize = 4096
303
292
const nBuf = 2048
304
293
305
- func PipeThenClose (src , dst net.Conn , is_http bool , is_res bool , host string , user user.User ) ( total int , raw_header [] byte ) {
294
+ func PipeThenClose (src , dst net.Conn , is_res bool , host string , user user.User ) {
306
295
var pipeBuf = leakybuf .NewLeakyBuf (nBuf , bufSize )
307
296
defer dst .Close ()
308
297
buf := pipeBuf .Get ()
309
298
// defer pipeBuf.Put(buf)
310
- var buffer = bytes .NewBuffer (nil )
311
- var is_end = false
312
299
var size int
313
300
314
301
for {
@@ -317,15 +304,6 @@ func PipeThenClose(src, dst net.Conn, is_http bool, is_res bool, host string, us
317
304
// read may return EOF with n > 0
318
305
// should always process n > 0 bytes before handling error
319
306
if n > 0 {
320
- if is_http && ! is_end {
321
- buffer .Write (buf )
322
- raw_header = buffer .Bytes ()
323
- lines := bytes .SplitN (raw_header , []byte ("\r \n \r \n " ), 2 )
324
- if len (lines ) == 2 {
325
- is_end = true
326
- }
327
- }
328
-
329
307
size , err = dst .Write (buf [0 :n ])
330
308
if is_res {
331
309
err = storage .IncrSize (user , size )
@@ -334,7 +312,6 @@ func PipeThenClose(src, dst net.Conn, is_http bool, is_res bool, host string, us
334
312
}
335
313
Log .Debug (fmt .Sprintf ("[port-%d] store size: %d" , user .GetPort (), size ))
336
314
}
337
- total += size
338
315
if err != nil {
339
316
Log .Debug ("write:" , err )
340
317
break
@@ -350,6 +327,69 @@ func PipeThenClose(src, dst net.Conn, is_http bool, is_res bool, host string, us
350
327
return
351
328
}
352
329
330
+ func PipeThenCloseOta (src * ss.Conn , dst net.Conn , is_res bool , host string , user user.User ) {
331
+ const (
332
+ dataLenLen = 2
333
+ hmacSha1Len = 10
334
+ idxData0 = dataLenLen + hmacSha1Len
335
+ )
336
+
337
+ defer func () {
338
+ dst .Close ()
339
+ }()
340
+ var pipeBuf = leakybuf .NewLeakyBuf (nBuf , bufSize )
341
+ buf := pipeBuf .Get ()
342
+ // sometimes it have to fill large block
343
+ for i := 1 ; ; i += 1 {
344
+ SetReadTimeout (src )
345
+ n , err := io .ReadFull (src , buf [:dataLenLen + hmacSha1Len ])
346
+ if err != nil {
347
+ if err == io .EOF {
348
+ break
349
+ }
350
+ Log .Debug (fmt .Sprintf ("conn=%p #%v read header error n=%v: %v" , src , i , n , err ))
351
+ break
352
+ }
353
+ dataLen := binary .BigEndian .Uint16 (buf [:dataLenLen ])
354
+ expectedHmacSha1 := buf [dataLenLen :idxData0 ]
355
+
356
+ var dataBuf []byte
357
+ if len (buf ) < int (idxData0 + dataLen ) {
358
+ dataBuf = make ([]byte , dataLen )
359
+ } else {
360
+ dataBuf = buf [idxData0 : idxData0 + dataLen ]
361
+ }
362
+ if n , err := io .ReadFull (src , dataBuf ); err != nil {
363
+ if err == io .EOF {
364
+ break
365
+ }
366
+ Log .Debug (fmt .Sprintf ("conn=%p #%v read data error n=%v: %v" , src , i , n , err ))
367
+ break
368
+ }
369
+ chunkIdBytes := make ([]byte , 4 )
370
+ chunkId := src .GetAndIncrChunkId ()
371
+ binary .BigEndian .PutUint32 (chunkIdBytes , chunkId )
372
+ actualHmacSha1 := ss .HmacSha1 (append (src .GetIv (), chunkIdBytes ... ), dataBuf )
373
+ if ! bytes .Equal (expectedHmacSha1 , actualHmacSha1 ) {
374
+ Log .Debug (fmt .Sprintf ("conn=%p #%v read data hmac-sha1 mismatch, iv=%v chunkId=%v src=%v dst=%v len=%v expeced=%v actual=%v" , src , i , src .GetIv (), chunkId , src .RemoteAddr (), dst .RemoteAddr (), dataLen , expectedHmacSha1 , actualHmacSha1 ))
375
+ break
376
+ }
377
+
378
+ if n , err := dst .Write (dataBuf ); err != nil {
379
+ Log .Debug (fmt .Sprintf ("conn=%p #%v write data error n=%v: %v" , dst , i , n , err ))
380
+ break
381
+ }
382
+ if is_res {
383
+ err := storage .IncrSize (user , n )
384
+ if err != nil {
385
+ Log .Error (err )
386
+ }
387
+ Log .Debug (fmt .Sprintf ("[port-%d] store size: %d" , user .GetPort (), n ))
388
+ }
389
+ }
390
+ return
391
+ }
392
+
353
393
var readTimeout time.Duration
354
394
355
395
func SetReadTimeout (c net.Conn ) {
0 commit comments