Skip to content
This repository was archived by the owner on Jun 16, 2019. It is now read-only.

Commit 64ced42

Browse files
committed
Merge pull request #48 from jemyzhang/master
try to enable ota
2 parents 6be75ab + ca3214f commit 64ced42

File tree

4 files changed

+177
-119
lines changed

4 files changed

+177
-119
lines changed

mu/func.go

Lines changed: 154 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -25,61 +25,65 @@ import (
2525
"time"
2626
)
2727

28+
const (
29+
idType = 0 // address type index
30+
idIP0 = 1 // ip addres start index
31+
idDmLen = 1 // domain address length index
32+
idDm0 = 2 // domain address start index
33+
34+
typeIPv4 = 1 // type is ipv4 address
35+
typeDm = 3 // type is domain address
36+
typeIPv6 = 4 // type is ipv6 address
37+
38+
lenIPv4 = net.IPv4len + 2 // ipv4 + 2port
39+
lenIPv6 = net.IPv6len + 2 // ipv6 + 2port
40+
lenDmBase = 2 // 1addrLen + 2port, plus addrLen
41+
lenHmacSha1 = 10
42+
)
43+
2844
var ssdebug ss.DebugLog
2945

30-
func getRequest(conn *ss.Conn) (host string, extra []byte, err error) {
31-
const (
32-
idType = 0 // address type index
33-
idIP0 = 1 // ip addres start index
34-
idDmLen = 1 // domain address length index
35-
idDm0 = 2 // domain address start index
36-
37-
typeIPv4 = 1 // type is ipv4 address
38-
typeDm = 3 // type is domain address
39-
typeIPv6 = 4 // type is ipv6 address
40-
41-
lenIPv4 = 1 + net.IPv4len + 2 // 1addrType + ipv4 + 2port
42-
lenIPv6 = 1 + net.IPv6len + 2 // 1addrType + ipv6 + 2port
43-
lenDmBase = 1 + 1 + 2 // 1addrType + 1addrLen + 2port, plus addrLen
44-
)
46+
func getRequest(conn *ss.Conn, auth bool) (host string, res_size int, ota bool, err error) {
47+
var n int
48+
ss.SetReadTimeout(conn)
4549

4650
// buf size should at least have the same size with the largest possible
4751
// request size (when addrType is 3, domain name has at most 256 bytes)
48-
// 1(addrType) + 1(lenByte) + 256(max length address) + 2(port)
49-
buf := make([]byte, 260)
50-
var n int
52+
// 1(addrType) + 1(lenByte) + 256(max length address) + 2(port) + 10(hmac-sha1)
53+
buf := make([]byte, 270)
5154
// read till we get possible domain length field
52-
ss.SetReadTimeout(conn)
53-
if n, err = io.ReadAtLeast(conn, buf, idDmLen+1); err != nil {
55+
if n, err = io.ReadFull(conn, buf[:idType+1]); err != nil {
5456
return
5557
}
58+
res_size += n
5659

57-
reqLen := -1
58-
switch buf[idType] {
60+
var reqStart, reqEnd int
61+
addrType := buf[idType]
62+
switch addrType & ss.AddrMask {
5963
case typeIPv4:
60-
reqLen = lenIPv4
64+
reqStart, reqEnd = idIP0, idIP0+lenIPv4
6165
case typeIPv6:
62-
reqLen = lenIPv6
66+
reqStart, reqEnd = idIP0, idIP0+lenIPv6
6367
case typeDm:
64-
reqLen = int(buf[idDmLen]) + lenDmBase
68+
if n, err = io.ReadFull(conn, buf[idType+1:idDmLen+1]); err != nil {
69+
return
70+
}
71+
reqStart, reqEnd = idDm0, int(idDm0+buf[idDmLen]+lenDmBase)
6572
default:
66-
err = fmt.Errorf("addr type %d not supported", buf[idType])
73+
err = fmt.Errorf("addr type %d not supported", addrType&ss.AddrMask)
6774
return
6875
}
76+
res_size += n
6977

70-
if n < reqLen { // rare case
71-
if _, err = io.ReadFull(conn, buf[n:reqLen]); err != nil {
72-
return
73-
}
74-
} else if n > reqLen {
75-
// it's possible to read more than just the request head
76-
extra = buf[reqLen:n]
78+
if n, err = io.ReadFull(conn, buf[reqStart:reqEnd]); err != nil {
79+
return
7780
}
81+
res_size += n
7882

7983
// Return string for typeIP is not most efficient, but browsers (Chrome,
8084
// Safari, Firefox) all seems using typeDm exclusively. So this is not a
8185
// big problem.
82-
switch buf[idType] {
86+
switch addrType & ss.AddrMask {
8387
case typeIPv4:
8488
host = net.IP(buf[idIP0 : idIP0+net.IPv4len]).String()
8589
case typeIPv6:
@@ -88,8 +92,23 @@ func getRequest(conn *ss.Conn) (host string, extra []byte, err error) {
8892
host = string(buf[idDm0 : idDm0+buf[idDmLen]])
8993
}
9094
// parse port
91-
port := binary.BigEndian.Uint16(buf[reqLen-2 : reqLen])
95+
port := binary.BigEndian.Uint16(buf[reqEnd-2 : reqEnd])
9296
host = net.JoinHostPort(host, strconv.Itoa(int(port)))
97+
// if specified one time auth enabled, we should verify this
98+
if auth || addrType&ss.OneTimeAuthMask > 0 {
99+
ota = true
100+
if n, err = io.ReadFull(conn, buf[reqEnd:reqEnd+lenHmacSha1]); err != nil {
101+
return
102+
}
103+
iv := conn.GetIv()
104+
key := conn.GetKey()
105+
actualHmacSha1Buf := ss.HmacSha1(append(iv, key...), buf[:reqEnd])
106+
if !bytes.Equal(buf[reqEnd:reqEnd+lenHmacSha1], actualHmacSha1Buf) {
107+
err = fmt.Errorf("verify one time auth failed, iv=%v key=%v data=%v", iv, key, buf[:reqEnd])
108+
return
109+
}
110+
res_size += n
111+
}
93112
return
94113
}
95114

@@ -98,13 +117,9 @@ const logCntDelta = 100
98117
var connCnt int
99118
var nextLogConnCnt int = logCntDelta
100119

101-
func handleConnection(user user.User, conn *ss.Conn) {
120+
func handleConnection(user user.User, conn *ss.Conn, auth bool) {
102121
var host string
103-
var size = 0
104-
var raw_req_header, raw_res_header []byte
105-
var is_http = false
106-
var res_size = 0
107-
var req_chan = make(chan []byte)
122+
108123
connCnt++ // this maybe not accurate, but should be enough
109124
if connCnt-nextLogConnCnt >= 0 {
110125
// XXX There's no xadd in the atomic package, so it's difficult to log
@@ -128,7 +143,7 @@ func handleConnection(user user.User, conn *ss.Conn) {
128143
}
129144
}()
130145

131-
host, extra, err := getRequest(conn)
146+
host, res_size, ota, err := getRequest(conn, auth)
132147
if err != nil {
133148
Log.Error("error getting request", conn.RemoteAddr(), conn.LocalAddr(), err)
134149
return
@@ -151,65 +166,30 @@ func handleConnection(user user.User, conn *ss.Conn) {
151166
}
152167
}()
153168

154-
defer func() {
155-
if is_http {
156-
tmp_req_header := <-req_chan
157-
buffer := bytes.NewBuffer(raw_req_header)
158-
buffer.Write(tmp_req_header)
159-
raw_req_header = buffer.Bytes()
160-
}
161-
showConn(raw_req_header, raw_res_header, host, user, size, is_http)
162-
close(req_chan)
163-
if !closed {
164-
remote.Close()
165-
}
166-
}()
169+
// debug conn info
170+
Log.Debug(fmt.Sprintf("%d conn debug: local addr: %s | remote addr: %s network: %s ", user.GetPort(),
171+
conn.LocalAddr().String(), conn.RemoteAddr().String(), conn.RemoteAddr().Network()))
172+
err = storage.IncrSize(user, res_size)
173+
if err != nil {
174+
Log.Error(err)
175+
return
176+
}
177+
err = storage.MarkUserOnline(user)
178+
if err != nil {
179+
Log.Error(err)
180+
return
181+
}
182+
Log.Debug(fmt.Sprintf("[port-%d] store size: %d", user.GetPort(), res_size))
167183

168-
// write extra bytes read from
169-
if extra != nil {
170-
// debug.Println("getRequest read extra data, writing to remote, len", len(extra))
171-
is_http, extra, _ = checkHttp(extra, conn)
172-
if strings.HasSuffix(host, ":80") {
173-
is_http = true
174-
}
175-
raw_req_header = extra
176-
res_size, err = remote.Write(extra)
177-
// size, err := remote.Write(extra)
178-
if err != nil {
179-
Log.Error("write request extra error:", err)
180-
return
181-
}
182-
// debug conn info
183-
Log.Debug(fmt.Sprintf("%d conn debug: local addr: %s | remote addr: %s network: %s ", user.GetPort(),
184-
conn.LocalAddr().String(), conn.RemoteAddr().String(), conn.RemoteAddr().Network()))
185-
err = storage.IncrSize(user, res_size)
186-
if err != nil {
187-
Log.Error(err)
188-
return
189-
}
190-
err = storage.MarkUserOnline(user)
191-
if err != nil {
192-
Log.Error(err)
193-
return
194-
}
195-
Log.Debug(fmt.Sprintf("[port-%d] store size: %d", user.GetPort(), res_size))
184+
Log.Info(fmt.Sprintf("piping %s<->%s ota=%v connOta=%v", conn.RemoteAddr(), host, ota, conn.IsOta()))
185+
186+
if ota {
187+
go PipeThenCloseOta(conn, remote, false, host, user)
188+
} else {
189+
go PipeThenClose(conn, remote, false, host, user)
196190
}
197-
Log.Debug(fmt.Sprintf("piping %s<->%s", conn.RemoteAddr(), host))
198-
/**
199-
go ss.PipeThenClose(conn, remote)
200-
ss.PipeThenClose(remote, conn)
201-
closed = true
202-
return
203-
**/
204-
go func() {
205-
_, raw_header := PipeThenClose(conn, remote, is_http, false, host, user)
206-
if is_http {
207-
req_chan <- raw_header
208-
}
209-
}()
210191

211-
res_size, raw_res_header = PipeThenClose(remote, conn, is_http, true, host, user)
212-
size += res_size
192+
PipeThenClose(remote, conn, true, host, user)
213193
closed = true
214194
return
215195
}
@@ -273,7 +253,7 @@ func runWithCustomMethod(user user.User) {
273253
os.Exit(1)
274254
}
275255
passwdManager.add(port, password, ln)
276-
cipher, err := user.GetCipher()
256+
cipher, err, auth := user.GetCipher()
277257
if err != nil {
278258
return
279259
}
@@ -288,27 +268,34 @@ func runWithCustomMethod(user user.User) {
288268
// Creating cipher upon first connection.
289269
if cipher == nil {
290270
Log.Debug("creating cipher for port:", port)
291-
cipher, err = ss.NewCipher(user.GetMethod(), password)
271+
method := user.GetMethod()
272+
273+
if strings.HasSuffix(method, "-auth") {
274+
method = method[:len(method)-5]
275+
auth = true
276+
} else {
277+
auth = false
278+
}
279+
280+
cipher, err = ss.NewCipher(method, password)
292281
if err != nil {
293282
Log.Error(fmt.Sprintf("Error generating cipher for port: %s %v\n", port, err))
294283
conn.Close()
295284
continue
296285
}
297286
}
298-
go handleConnection(user, ss.NewConn(conn, cipher.Copy()))
287+
go handleConnection(user, ss.NewConn(conn, cipher.Copy()), auth)
299288
}
300289
}
301290

302291
const bufSize = 4096
303292
const nBuf = 2048
304293

305-
func PipeThenClose(src, dst net.Conn, is_http bool, is_res bool, host string, user user.User) (total int, raw_header []byte) {
294+
func PipeThenClose(src, dst net.Conn, is_res bool, host string, user user.User) {
306295
var pipeBuf = leakybuf.NewLeakyBuf(nBuf, bufSize)
307296
defer dst.Close()
308297
buf := pipeBuf.Get()
309298
// defer pipeBuf.Put(buf)
310-
var buffer = bytes.NewBuffer(nil)
311-
var is_end = false
312299
var size int
313300

314301
for {
@@ -317,15 +304,6 @@ func PipeThenClose(src, dst net.Conn, is_http bool, is_res bool, host string, us
317304
// read may return EOF with n > 0
318305
// should always process n > 0 bytes before handling error
319306
if n > 0 {
320-
if is_http && !is_end {
321-
buffer.Write(buf)
322-
raw_header = buffer.Bytes()
323-
lines := bytes.SplitN(raw_header, []byte("\r\n\r\n"), 2)
324-
if len(lines) == 2 {
325-
is_end = true
326-
}
327-
}
328-
329307
size, err = dst.Write(buf[0:n])
330308
if is_res {
331309
err = storage.IncrSize(user, size)
@@ -334,7 +312,6 @@ func PipeThenClose(src, dst net.Conn, is_http bool, is_res bool, host string, us
334312
}
335313
Log.Debug(fmt.Sprintf("[port-%d] store size: %d", user.GetPort(), size))
336314
}
337-
total += size
338315
if err != nil {
339316
Log.Debug("write:", err)
340317
break
@@ -350,6 +327,69 @@ func PipeThenClose(src, dst net.Conn, is_http bool, is_res bool, host string, us
350327
return
351328
}
352329

330+
func PipeThenCloseOta(src *ss.Conn, dst net.Conn, is_res bool, host string, user user.User) {
331+
const (
332+
dataLenLen = 2
333+
hmacSha1Len = 10
334+
idxData0 = dataLenLen + hmacSha1Len
335+
)
336+
337+
defer func() {
338+
dst.Close()
339+
}()
340+
var pipeBuf = leakybuf.NewLeakyBuf(nBuf, bufSize)
341+
buf := pipeBuf.Get()
342+
// sometimes it have to fill large block
343+
for i := 1; ; i += 1 {
344+
SetReadTimeout(src)
345+
n, err := io.ReadFull(src, buf[:dataLenLen+hmacSha1Len])
346+
if err != nil {
347+
if err == io.EOF {
348+
break
349+
}
350+
Log.Debug(fmt.Sprintf("conn=%p #%v read header error n=%v: %v", src, i, n, err))
351+
break
352+
}
353+
dataLen := binary.BigEndian.Uint16(buf[:dataLenLen])
354+
expectedHmacSha1 := buf[dataLenLen:idxData0]
355+
356+
var dataBuf []byte
357+
if len(buf) < int(idxData0+dataLen) {
358+
dataBuf = make([]byte, dataLen)
359+
} else {
360+
dataBuf = buf[idxData0 : idxData0+dataLen]
361+
}
362+
if n, err := io.ReadFull(src, dataBuf); err != nil {
363+
if err == io.EOF {
364+
break
365+
}
366+
Log.Debug(fmt.Sprintf("conn=%p #%v read data error n=%v: %v", src, i, n, err))
367+
break
368+
}
369+
chunkIdBytes := make([]byte, 4)
370+
chunkId := src.GetAndIncrChunkId()
371+
binary.BigEndian.PutUint32(chunkIdBytes, chunkId)
372+
actualHmacSha1 := ss.HmacSha1(append(src.GetIv(), chunkIdBytes...), dataBuf)
373+
if !bytes.Equal(expectedHmacSha1, actualHmacSha1) {
374+
Log.Debug(fmt.Sprintf("conn=%p #%v read data hmac-sha1 mismatch, iv=%v chunkId=%v src=%v dst=%v len=%v expeced=%v actual=%v", src, i, src.GetIv(), chunkId, src.RemoteAddr(), dst.RemoteAddr(), dataLen, expectedHmacSha1, actualHmacSha1))
375+
break
376+
}
377+
378+
if n, err := dst.Write(dataBuf); err != nil {
379+
Log.Debug(fmt.Sprintf("conn=%p #%v write data error n=%v: %v", dst, i, n, err))
380+
break
381+
}
382+
if is_res {
383+
err := storage.IncrSize(user, n)
384+
if err != nil {
385+
Log.Error(err)
386+
}
387+
Log.Debug(fmt.Sprintf("[port-%d] store size: %d", user.GetPort(), n))
388+
}
389+
}
390+
return
391+
}
392+
353393
var readTimeout time.Duration
354394

355395
func SetReadTimeout(c net.Conn) {

0 commit comments

Comments
 (0)