Skip to content

Commit 96816de

Browse files
authored
Apply suggestions from Ben Hoyt's review (#5)
1 parent 178f291 commit 96816de

File tree

7 files changed

+136
-90
lines changed

7 files changed

+136
-90
lines changed

.github/workflows/publish.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313

1414
publish:
1515
name: Publish Aproxy
16-
runs-on: [ self-hosted, linux, x64, large ]
16+
runs-on: ubuntu-latest
1717
needs: [ tests, integration-tests ]
1818

1919
steps:

.github/workflows/tests.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ jobs:
1717
with:
1818
go-version: 1.21
1919

20+
- name: Ensure No Formatting Changes
21+
run: |
22+
go fmt ./...
23+
git diff --exit-code
24+
2025
- name: Build and Test
2126
run: |
2227
go test -race ./...

README.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,17 @@ You can inspect the access logs of aproxy using:
4343
```bash
4444
sudo snap logs aproxy.aproxy -n=all
4545
```
46+
47+
## Running from Source
48+
49+
To run this application directly from the source code, you'll need to have Go
50+
1.21 installed on your system.
51+
52+
Follow these steps to get started:
53+
54+
```bash
55+
git clone https://github.com/canonical/aproxy.git
56+
cd aproxy
57+
go mod download
58+
go run . --proxy=squid.internal:3128
59+
```

aproxy.go

Lines changed: 82 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,88 +1,84 @@
1-
//go:build linux
2-
31
package main
42

53
import (
64
"bufio"
5+
"context"
76
"encoding/binary"
87
"errors"
98
"flag"
109
"fmt"
11-
"golang.org/x/crypto/cryptobyte"
1210
"io"
1311
"log"
1412
"log/slog"
1513
"net"
1614
"net/http"
1715
"net/url"
1816
"os"
17+
"os/signal"
1918
"strings"
2019
"sync"
2120
"sync/atomic"
2221
"syscall"
23-
"unsafe"
22+
23+
"golang.org/x/crypto/cryptobyte"
2424
)
2525

26-
var version = "0.2.0"
26+
var version = "0.2.1"
2727

28+
// PrereadConn is a wrapper around net.Conn that supports pre-reading from the underlying connection.
29+
// Any Read before the EndPreread can be undone and read again by calling the EndPreread function.
2830
type PrereadConn struct {
29-
prereadStarted bool
30-
prereadEnded bool
31-
prereadBuf []byte
32-
mu sync.Mutex
33-
conn net.Conn
34-
}
35-
36-
func (c *PrereadConn) StartPreread() {
37-
c.mu.Lock()
38-
defer c.mu.Unlock()
39-
if c.prereadStarted {
40-
panic("call StartPreread after preread has already started or ended")
41-
}
42-
c.prereadStarted = true
31+
ended bool
32+
buf []byte
33+
mu sync.Mutex
34+
conn net.Conn
4335
}
4436

45-
func (c *PrereadConn) RestorePreread() {
37+
// EndPreread ends the pre-reading phase. Any Read before will be undone and data in the stream can be read again.
38+
// EndPreread can be only called once.
39+
func (c *PrereadConn) EndPreread() {
4640
c.mu.Lock()
4741
defer c.mu.Unlock()
48-
if !c.prereadStarted || c.prereadEnded {
49-
panic("call RestorePreread after preread has ended or hasn't started")
42+
if c.ended {
43+
panic("call EndPreread after preread has ended or hasn't started")
5044
}
51-
c.prereadEnded = true
45+
c.ended = true
5246
}
5347

48+
// Read reads from the underlying connection. Read during the pre-reading phase can be undone by EndPreread.
5449
func (c *PrereadConn) Read(p []byte) (n int, err error) {
5550
c.mu.Lock()
5651
defer c.mu.Unlock()
57-
if c.prereadEnded {
58-
n = copy(p, c.prereadBuf)
59-
bufLen := len(c.prereadBuf)
60-
c.prereadBuf = c.prereadBuf[n:]
52+
if c.ended {
53+
n = copy(p, c.buf)
54+
bufLen := len(c.buf)
55+
c.buf = c.buf[n:]
6156
if n == len(p) || (bufLen > 0 && bufLen == n) {
6257
return n, nil
6358
}
6459
rn, err := c.conn.Read(p[n:])
6560
return rn + n, err
66-
}
67-
if c.prereadStarted {
61+
} else {
6862
n, err = c.conn.Read(p)
69-
c.prereadBuf = append(c.prereadBuf, p[:n]...)
63+
c.buf = append(c.buf, p[:n]...)
7064
return n, err
7165
}
72-
return c.conn.Read(p)
7366
}
7467

68+
// Write writes data to the underlying connection.
7569
func (c *PrereadConn) Write(p []byte) (n int, err error) {
7670
return c.conn.Write(p)
7771
}
7872

73+
// NewPrereadConn wraps the network connection and return a *PrereadConn.
74+
// It's recommended to not touch the original connection after wrapped.
7975
func NewPrereadConn(conn net.Conn) *PrereadConn {
8076
return &PrereadConn{conn: conn}
8177
}
8278

79+
// PrereadSNI pre-reads the Server Name Indication (SNI) from a TLS connection.
8380
func PrereadSNI(conn *PrereadConn) (_ string, err error) {
84-
conn.StartPreread()
85-
defer conn.RestorePreread()
81+
defer conn.EndPreread()
8682
defer func() {
8783
if err != nil {
8884
err = fmt.Errorf("failed to preread TLS client hello: %w", err)
@@ -114,11 +110,11 @@ func PrereadSNI(conn *PrereadConn) (_ string, err error) {
114110

115111
func extractSNI(data []byte) (string, error) {
116112
s := cryptobyte.String(data)
117-
var version uint16
118-
var random []byte
119-
var sessionId []byte
120-
var compressionMethods []byte
121-
var cipherSuites []uint16
113+
var (
114+
version uint16
115+
random []byte
116+
sessionId []byte
117+
)
122118

123119
if !s.Skip(9) ||
124120
!s.ReadUint16(&version) || !s.ReadBytes(&random, 32) ||
@@ -130,6 +126,8 @@ func extractSNI(data []byte) (string, error) {
130126
if !s.ReadUint16LengthPrefixed(&cipherSuitesData) {
131127
return "", fmt.Errorf("failed to parse TLS client hello cipher suites")
132128
}
129+
130+
var cipherSuites []uint16
133131
for !cipherSuitesData.Empty() {
134132
var suite uint16
135133
if !cipherSuitesData.ReadUint16(&suite) {
@@ -138,6 +136,7 @@ func extractSNI(data []byte) (string, error) {
138136
cipherSuites = append(cipherSuites, suite)
139137
}
140138

139+
var compressionMethods []byte
141140
if !s.ReadUint8LengthPrefixed((*cryptobyte.String)(&compressionMethods)) {
142141
return "", fmt.Errorf("failed to parse TLS client hello compression methods")
143142
}
@@ -191,15 +190,15 @@ func extractSNI(data []byte) (string, error) {
191190
return finalServerName, nil
192191
}
193192

193+
// PrereadHttpHost pre-reads the HTTP Host header from an HTTP connection.
194194
func PrereadHttpHost(conn *PrereadConn) (_ string, err error) {
195195
defer func() {
196196
if err != nil {
197197
err = fmt.Errorf("failed to preread HTTP request: %w", err)
198198
}
199199
}()
200200

201-
conn.StartPreread()
202-
defer conn.RestorePreread()
201+
defer conn.EndPreread()
203202
req, err := http.ReadRequest(bufio.NewReader(conn))
204203
if err != nil {
205204
return "", err
@@ -211,6 +210,7 @@ func PrereadHttpHost(conn *PrereadConn) (_ string, err error) {
211210
return host, nil
212211
}
213212

213+
// DialProxy dials the TCP connection to the proxy.
214214
func DialProxy(proxy string) (net.Conn, error) {
215215
proxyAddr, err := net.ResolveTCPAddr("tcp", proxy)
216216
if err != nil {
@@ -223,6 +223,7 @@ func DialProxy(proxy string) (net.Conn, error) {
223223
return conn, nil
224224
}
225225

226+
// DialProxyConnect dials the TCP connection and finishes the HTTP CONNECT handshake with the proxy.
226227
func DialProxyConnect(proxy string, dst string) (net.Conn, error) {
227228
conn, err := DialProxy(proxy)
228229
if err != nil {
@@ -255,6 +256,7 @@ func DialProxyConnect(proxy string, dst string) (net.Conn, error) {
255256
return conn, nil
256257
}
257258

259+
// GetOriginalDst get the original destination address of a TCP connection before dstnat.
258260
func GetOriginalDst(conn *net.TCPConn) (*net.TCPAddr, error) {
259261
file, err := conn.File()
260262
defer func(file *os.File) {
@@ -266,27 +268,15 @@ func GetOriginalDst(conn *net.TCPConn) (*net.TCPAddr, error) {
266268
if err != nil {
267269
return nil, fmt.Errorf("failed to convert connection to file: %w", err)
268270
}
269-
var sockaddr [16]byte
270-
size := 16
271-
_, _, e := syscall.Syscall6(
272-
syscall.SYS_GETSOCKOPT,
273-
file.Fd(),
271+
return GetsockoptIPv4OriginalDst(
272+
int(file.Fd()),
274273
syscall.SOL_IP,
275274
80, // SO_ORIGINAL_DST
276-
uintptr(unsafe.Pointer(&sockaddr)),
277-
uintptr(unsafe.Pointer(&size)),
278-
0,
279275
)
280-
if e != 0 {
281-
return nil, fmt.Errorf("getsockopt SO_ORIGINAL_DST failed: errno %d", e)
282-
}
283-
return &net.TCPAddr{
284-
IP: sockaddr[4:8],
285-
Port: int(binary.BigEndian.Uint16(sockaddr[2:4])),
286-
}, nil
287276
}
288277

289-
func RelayTcp(conn io.ReadWriter, proxyConn io.ReadWriteCloser, logger *slog.Logger) {
278+
// RelayTCP relays data between the incoming TCP connection and the proxy connection.
279+
func RelayTCP(conn io.ReadWriter, proxyConn io.ReadWriteCloser, logger *slog.Logger) {
290280
var closed atomic.Bool
291281
go func() {
292282
_, err := io.Copy(proxyConn, conn)
@@ -303,10 +293,10 @@ func RelayTcp(conn io.ReadWriter, proxyConn io.ReadWriteCloser, logger *slog.Log
303293
closed.Store(true)
304294
}
305295

306-
func RelayHttp(conn io.ReadWriter, proxyConn io.ReadWriteCloser, logger *slog.Logger) {
307-
defer func() {
308-
_ = proxyConn.Close()
309-
}()
296+
// RelayHTTP relays a single HTTP request and response between a local connection and a proxy.
297+
// It modifies the Connection header to "close" in both the request and response.
298+
func RelayHTTP(conn io.ReadWriter, proxyConn io.ReadWriteCloser, logger *slog.Logger) {
299+
defer proxyConn.Close()
310300
req, err := http.ReadRequest(bufio.NewReader(conn))
311301
if err != nil {
312302
logger.Error("failed to read HTTP request from connection", "error", err)
@@ -331,59 +321,65 @@ func RelayHttp(conn io.ReadWriter, proxyConn io.ReadWriteCloser, logger *slog.Lo
331321
}
332322
}
333323

324+
// HandleConn manages the incoming connections.
334325
func HandleConn(conn net.Conn, proxy string) {
335-
defer func() { _ = conn.Close() }()
326+
defer conn.Close()
336327
logger := slog.With("src", conn.RemoteAddr())
337328
dst, err := GetOriginalDst(conn.(*net.TCPConn))
338329
if err != nil {
339330
slog.Error("failed to get connection original destination", "error", err)
340331
return
341332
}
342333
logger = logger.With("original_dst", dst)
343-
var host string
344-
var relay func(conn io.ReadWriter, proxyConn io.ReadWriteCloser, logger *slog.Logger)
345-
var dialProxy func(proxy string) (net.Conn, error)
346334
consigned := NewPrereadConn(conn)
347335
switch dst.Port {
348336
case 443:
349-
relay = RelayTcp
350-
sni, sniErr := PrereadSNI(consigned)
351-
if sniErr != nil {
352-
err = sniErr
337+
sni, err := PrereadSNI(consigned)
338+
if err != nil {
339+
logger.Error("failed to preread SNI from connection", "error", err)
340+
return
353341
} else {
354-
host = fmt.Sprintf("%s:%d", sni, dst.Port)
355-
dialProxy = func(proxy string) (net.Conn, error) { return DialProxyConnect(proxy, host) }
342+
host := fmt.Sprintf("%s:%d", sni, dst.Port)
343+
logger = logger.With("host", host)
344+
proxyConn, err := DialProxyConnect(proxy, host)
345+
if err != nil {
346+
logger.Error("failed to connect to http proxy", "error", err)
347+
return
348+
}
349+
logger.Info("relay TLS connection to proxy")
350+
RelayTCP(consigned, proxyConn, logger)
356351
}
357352
case 80:
358-
relay = RelayHttp
359-
host, err = PrereadHttpHost(consigned)
353+
host, err := PrereadHttpHost(consigned)
354+
if err != nil {
355+
logger.Error("failed to preread HTTP host from connection", "error", err)
356+
return
357+
}
360358
if !strings.Contains(host, ":") {
361359
host = fmt.Sprintf("%s:%d", host, dst.Port)
362360
}
363-
dialProxy = DialProxy
361+
logger = logger.With("host", host)
362+
proxyConn, err := DialProxy(proxy)
363+
if err != nil {
364+
logger.Error("failed to connect to http proxy", "error", err)
365+
return
366+
}
367+
logger.Info("relay HTTP connection to proxy")
368+
RelayHTTP(consigned, proxyConn, logger)
364369
default:
365370
logger.Error(fmt.Sprintf("unknown destination port: %d", dst.Port))
366371
return
367372
}
368-
if err != nil {
369-
logger.Error("failed to preread host from connection", "error", err)
370-
return
371-
}
372-
logger = logger.With("host", host)
373-
proxyConn, err := dialProxy(proxy)
374-
if err != nil {
375-
logger.Error("failed to connect to http proxy", "error", err)
376-
return
377-
}
378-
logger.Info("relay connection to http proxy")
379-
relay(consigned, proxyConn, logger)
380373
}
381374

382375
func main() {
383376
proxyFlag := flag.String("proxy", "", "upstream HTTP proxy address in the 'host:port' format")
384377
flag.Parse()
385378
listenAddr := &net.TCPAddr{IP: net.ParseIP("0.0.0.0"), Port: 8443}
386-
listener, err := net.ListenTCP("tcp", listenAddr)
379+
ctx := context.Background()
380+
signal.NotifyContext(ctx, os.Interrupt)
381+
listenConfig := new(net.ListenConfig)
382+
listener, err := listenConfig.Listen(ctx, "tcp", fmt.Sprintf(":8443"))
387383
if err != nil {
388384
log.Fatalf("failed to listen on %s", listenAddr.String())
389385
}

aproxy_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ func TestPrereadConn(t *testing.T) {
1111
remote, local := net.Pipe()
1212
go remote.Write([]byte("hello, world"))
1313
preread := &PrereadConn{conn: local}
14-
preread.StartPreread()
1514
buf := make([]byte, 5)
1615
_, err := preread.Read(buf)
1716
if err != nil {
@@ -22,7 +21,7 @@ func TestPrereadConn(t *testing.T) {
2221
if err != nil {
2322
t.Fatalf("Read failed during preread: %s", err)
2423
}
25-
preread.RestorePreread()
24+
preread.EndPreread()
2625
buf2 := make([]byte, 12)
2726
_, err = io.ReadFull(preread, buf2)
2827
if err != nil {

snap/snapcraft.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
name: aproxy
2-
version: 0.2.0
2+
version: 0.2.1
33
summary: Transparent proxy for HTTP and HTTPS/TLS connections.
44
description: |
55
Aproxy is a transparent proxy for HTTP and HTTPS/TLS connections. By

0 commit comments

Comments
 (0)