1
- //go:build linux
2
-
3
1
package main
4
2
5
3
import (
6
4
"bufio"
5
+ "context"
7
6
"encoding/binary"
8
7
"errors"
9
8
"flag"
10
9
"fmt"
11
- "golang.org/x/crypto/cryptobyte"
12
10
"io"
13
11
"log"
14
12
"log/slog"
15
13
"net"
16
14
"net/http"
17
15
"net/url"
18
16
"os"
17
+ "os/signal"
19
18
"strings"
20
19
"sync"
21
20
"sync/atomic"
22
21
"syscall"
23
- "unsafe"
22
+
23
+ "golang.org/x/crypto/cryptobyte"
24
24
)
25
25
26
- var version = "0.2.0 "
26
+ var version = "0.2.1 "
27
27
28
+ // PrereadConn is a wrapper around net.Conn that supports pre-reading from the underlying connection.
29
+ // Any Read before the EndPreread can be undone and read again by calling the EndPreread function.
28
30
type PrereadConn struct {
29
- prereadStarted bool
30
- prereadEnded bool
31
- prereadBuf []byte
32
- mu sync.Mutex
33
- conn net.Conn
34
- }
35
-
36
- func (c * PrereadConn ) StartPreread () {
37
- c .mu .Lock ()
38
- defer c .mu .Unlock ()
39
- if c .prereadStarted {
40
- panic ("call StartPreread after preread has already started or ended" )
41
- }
42
- c .prereadStarted = true
31
+ ended bool
32
+ buf []byte
33
+ mu sync.Mutex
34
+ conn net.Conn
43
35
}
44
36
45
- func (c * PrereadConn ) RestorePreread () {
37
+ // EndPreread ends the pre-reading phase. Any Read before will be undone and data in the stream can be read again.
38
+ // EndPreread can be only called once.
39
+ func (c * PrereadConn ) EndPreread () {
46
40
c .mu .Lock ()
47
41
defer c .mu .Unlock ()
48
- if ! c . prereadStarted || c . prereadEnded {
49
- panic ("call RestorePreread after preread has ended or hasn't started" )
42
+ if c . ended {
43
+ panic ("call EndPreread after preread has ended or hasn't started" )
50
44
}
51
- c .prereadEnded = true
45
+ c .ended = true
52
46
}
53
47
48
+ // Read reads from the underlying connection. Read during the pre-reading phase can be undone by EndPreread.
54
49
func (c * PrereadConn ) Read (p []byte ) (n int , err error ) {
55
50
c .mu .Lock ()
56
51
defer c .mu .Unlock ()
57
- if c .prereadEnded {
58
- n = copy (p , c .prereadBuf )
59
- bufLen := len (c .prereadBuf )
60
- c .prereadBuf = c .prereadBuf [n :]
52
+ if c .ended {
53
+ n = copy (p , c .buf )
54
+ bufLen := len (c .buf )
55
+ c .buf = c .buf [n :]
61
56
if n == len (p ) || (bufLen > 0 && bufLen == n ) {
62
57
return n , nil
63
58
}
64
59
rn , err := c .conn .Read (p [n :])
65
60
return rn + n , err
66
- }
67
- if c .prereadStarted {
61
+ } else {
68
62
n , err = c .conn .Read (p )
69
- c .prereadBuf = append (c .prereadBuf , p [:n ]... )
63
+ c .buf = append (c .buf , p [:n ]... )
70
64
return n , err
71
65
}
72
- return c .conn .Read (p )
73
66
}
74
67
68
+ // Write writes data to the underlying connection.
75
69
func (c * PrereadConn ) Write (p []byte ) (n int , err error ) {
76
70
return c .conn .Write (p )
77
71
}
78
72
73
+ // NewPrereadConn wraps the network connection and return a *PrereadConn.
74
+ // It's recommended to not touch the original connection after wrapped.
79
75
func NewPrereadConn (conn net.Conn ) * PrereadConn {
80
76
return & PrereadConn {conn : conn }
81
77
}
82
78
79
+ // PrereadSNI pre-reads the Server Name Indication (SNI) from a TLS connection.
83
80
func PrereadSNI (conn * PrereadConn ) (_ string , err error ) {
84
- conn .StartPreread ()
85
- defer conn .RestorePreread ()
81
+ defer conn .EndPreread ()
86
82
defer func () {
87
83
if err != nil {
88
84
err = fmt .Errorf ("failed to preread TLS client hello: %w" , err )
@@ -114,11 +110,11 @@ func PrereadSNI(conn *PrereadConn) (_ string, err error) {
114
110
115
111
func extractSNI (data []byte ) (string , error ) {
116
112
s := cryptobyte .String (data )
117
- var version uint16
118
- var random [] byte
119
- var sessionId []byte
120
- var compressionMethods []byte
121
- var cipherSuites [] uint16
113
+ var (
114
+ version uint16
115
+ random []byte
116
+ sessionId []byte
117
+ )
122
118
123
119
if ! s .Skip (9 ) ||
124
120
! s .ReadUint16 (& version ) || ! s .ReadBytes (& random , 32 ) ||
@@ -130,6 +126,8 @@ func extractSNI(data []byte) (string, error) {
130
126
if ! s .ReadUint16LengthPrefixed (& cipherSuitesData ) {
131
127
return "" , fmt .Errorf ("failed to parse TLS client hello cipher suites" )
132
128
}
129
+
130
+ var cipherSuites []uint16
133
131
for ! cipherSuitesData .Empty () {
134
132
var suite uint16
135
133
if ! cipherSuitesData .ReadUint16 (& suite ) {
@@ -138,6 +136,7 @@ func extractSNI(data []byte) (string, error) {
138
136
cipherSuites = append (cipherSuites , suite )
139
137
}
140
138
139
+ var compressionMethods []byte
141
140
if ! s .ReadUint8LengthPrefixed ((* cryptobyte .String )(& compressionMethods )) {
142
141
return "" , fmt .Errorf ("failed to parse TLS client hello compression methods" )
143
142
}
@@ -191,15 +190,15 @@ func extractSNI(data []byte) (string, error) {
191
190
return finalServerName , nil
192
191
}
193
192
193
+ // PrereadHttpHost pre-reads the HTTP Host header from an HTTP connection.
194
194
func PrereadHttpHost (conn * PrereadConn ) (_ string , err error ) {
195
195
defer func () {
196
196
if err != nil {
197
197
err = fmt .Errorf ("failed to preread HTTP request: %w" , err )
198
198
}
199
199
}()
200
200
201
- conn .StartPreread ()
202
- defer conn .RestorePreread ()
201
+ defer conn .EndPreread ()
203
202
req , err := http .ReadRequest (bufio .NewReader (conn ))
204
203
if err != nil {
205
204
return "" , err
@@ -211,6 +210,7 @@ func PrereadHttpHost(conn *PrereadConn) (_ string, err error) {
211
210
return host , nil
212
211
}
213
212
213
+ // DialProxy dials the TCP connection to the proxy.
214
214
func DialProxy (proxy string ) (net.Conn , error ) {
215
215
proxyAddr , err := net .ResolveTCPAddr ("tcp" , proxy )
216
216
if err != nil {
@@ -223,6 +223,7 @@ func DialProxy(proxy string) (net.Conn, error) {
223
223
return conn , nil
224
224
}
225
225
226
+ // DialProxyConnect dials the TCP connection and finishes the HTTP CONNECT handshake with the proxy.
226
227
func DialProxyConnect (proxy string , dst string ) (net.Conn , error ) {
227
228
conn , err := DialProxy (proxy )
228
229
if err != nil {
@@ -255,6 +256,7 @@ func DialProxyConnect(proxy string, dst string) (net.Conn, error) {
255
256
return conn , nil
256
257
}
257
258
259
+ // GetOriginalDst get the original destination address of a TCP connection before dstnat.
258
260
func GetOriginalDst (conn * net.TCPConn ) (* net.TCPAddr , error ) {
259
261
file , err := conn .File ()
260
262
defer func (file * os.File ) {
@@ -266,27 +268,15 @@ func GetOriginalDst(conn *net.TCPConn) (*net.TCPAddr, error) {
266
268
if err != nil {
267
269
return nil , fmt .Errorf ("failed to convert connection to file: %w" , err )
268
270
}
269
- var sockaddr [16 ]byte
270
- size := 16
271
- _ , _ , e := syscall .Syscall6 (
272
- syscall .SYS_GETSOCKOPT ,
273
- file .Fd (),
271
+ return GetsockoptIPv4OriginalDst (
272
+ int (file .Fd ()),
274
273
syscall .SOL_IP ,
275
274
80 , // SO_ORIGINAL_DST
276
- uintptr (unsafe .Pointer (& sockaddr )),
277
- uintptr (unsafe .Pointer (& size )),
278
- 0 ,
279
275
)
280
- if e != 0 {
281
- return nil , fmt .Errorf ("getsockopt SO_ORIGINAL_DST failed: errno %d" , e )
282
- }
283
- return & net.TCPAddr {
284
- IP : sockaddr [4 :8 ],
285
- Port : int (binary .BigEndian .Uint16 (sockaddr [2 :4 ])),
286
- }, nil
287
276
}
288
277
289
- func RelayTcp (conn io.ReadWriter , proxyConn io.ReadWriteCloser , logger * slog.Logger ) {
278
+ // RelayTCP relays data between the incoming TCP connection and the proxy connection.
279
+ func RelayTCP (conn io.ReadWriter , proxyConn io.ReadWriteCloser , logger * slog.Logger ) {
290
280
var closed atomic.Bool
291
281
go func () {
292
282
_ , err := io .Copy (proxyConn , conn )
@@ -303,10 +293,10 @@ func RelayTcp(conn io.ReadWriter, proxyConn io.ReadWriteCloser, logger *slog.Log
303
293
closed .Store (true )
304
294
}
305
295
306
- func RelayHttp ( conn io. ReadWriter , proxyConn io. ReadWriteCloser , logger * slog. Logger ) {
307
- defer func () {
308
- _ = proxyConn . Close ()
309
- } ()
296
+ // RelayHTTP relays a single HTTP request and response between a local connection and a proxy.
297
+ // It modifies the Connection header to "close" in both the request and response.
298
+ func RelayHTTP ( conn io. ReadWriter , proxyConn io. ReadWriteCloser , logger * slog. Logger ) {
299
+ defer proxyConn . Close ()
310
300
req , err := http .ReadRequest (bufio .NewReader (conn ))
311
301
if err != nil {
312
302
logger .Error ("failed to read HTTP request from connection" , "error" , err )
@@ -331,59 +321,65 @@ func RelayHttp(conn io.ReadWriter, proxyConn io.ReadWriteCloser, logger *slog.Lo
331
321
}
332
322
}
333
323
324
+ // HandleConn manages the incoming connections.
334
325
func HandleConn (conn net.Conn , proxy string ) {
335
- defer func () { _ = conn .Close () } ()
326
+ defer conn .Close ()
336
327
logger := slog .With ("src" , conn .RemoteAddr ())
337
328
dst , err := GetOriginalDst (conn .(* net.TCPConn ))
338
329
if err != nil {
339
330
slog .Error ("failed to get connection original destination" , "error" , err )
340
331
return
341
332
}
342
333
logger = logger .With ("original_dst" , dst )
343
- var host string
344
- var relay func (conn io.ReadWriter , proxyConn io.ReadWriteCloser , logger * slog.Logger )
345
- var dialProxy func (proxy string ) (net.Conn , error )
346
334
consigned := NewPrereadConn (conn )
347
335
switch dst .Port {
348
336
case 443 :
349
- relay = RelayTcp
350
- sni , sniErr := PrereadSNI ( consigned )
351
- if sniErr != nil {
352
- err = sniErr
337
+ sni , err := PrereadSNI ( consigned )
338
+ if err != nil {
339
+ logger . Error ( "failed to preread SNI from connection" , "error" , err )
340
+ return
353
341
} else {
354
- host = fmt .Sprintf ("%s:%d" , sni , dst .Port )
355
- dialProxy = func (proxy string ) (net.Conn , error ) { return DialProxyConnect (proxy , host ) }
342
+ host := fmt .Sprintf ("%s:%d" , sni , dst .Port )
343
+ logger = logger .With ("host" , host )
344
+ proxyConn , err := DialProxyConnect (proxy , host )
345
+ if err != nil {
346
+ logger .Error ("failed to connect to http proxy" , "error" , err )
347
+ return
348
+ }
349
+ logger .Info ("relay TLS connection to proxy" )
350
+ RelayTCP (consigned , proxyConn , logger )
356
351
}
357
352
case 80 :
358
- relay = RelayHttp
359
- host , err = PrereadHttpHost (consigned )
353
+ host , err := PrereadHttpHost (consigned )
354
+ if err != nil {
355
+ logger .Error ("failed to preread HTTP host from connection" , "error" , err )
356
+ return
357
+ }
360
358
if ! strings .Contains (host , ":" ) {
361
359
host = fmt .Sprintf ("%s:%d" , host , dst .Port )
362
360
}
363
- dialProxy = DialProxy
361
+ logger = logger .With ("host" , host )
362
+ proxyConn , err := DialProxy (proxy )
363
+ if err != nil {
364
+ logger .Error ("failed to connect to http proxy" , "error" , err )
365
+ return
366
+ }
367
+ logger .Info ("relay HTTP connection to proxy" )
368
+ RelayHTTP (consigned , proxyConn , logger )
364
369
default :
365
370
logger .Error (fmt .Sprintf ("unknown destination port: %d" , dst .Port ))
366
371
return
367
372
}
368
- if err != nil {
369
- logger .Error ("failed to preread host from connection" , "error" , err )
370
- return
371
- }
372
- logger = logger .With ("host" , host )
373
- proxyConn , err := dialProxy (proxy )
374
- if err != nil {
375
- logger .Error ("failed to connect to http proxy" , "error" , err )
376
- return
377
- }
378
- logger .Info ("relay connection to http proxy" )
379
- relay (consigned , proxyConn , logger )
380
373
}
381
374
382
375
func main () {
383
376
proxyFlag := flag .String ("proxy" , "" , "upstream HTTP proxy address in the 'host:port' format" )
384
377
flag .Parse ()
385
378
listenAddr := & net.TCPAddr {IP : net .ParseIP ("0.0.0.0" ), Port : 8443 }
386
- listener , err := net .ListenTCP ("tcp" , listenAddr )
379
+ ctx := context .Background ()
380
+ signal .NotifyContext (ctx , os .Interrupt )
381
+ listenConfig := new (net.ListenConfig )
382
+ listener , err := listenConfig .Listen (ctx , "tcp" , fmt .Sprintf (":8443" ))
387
383
if err != nil {
388
384
log .Fatalf ("failed to listen on %s" , listenAddr .String ())
389
385
}
0 commit comments