diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index cccf4a0..97cc60f 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -13,7 +13,7 @@ jobs: publish: name: Publish Aproxy - runs-on: [ self-hosted, linux, x64, large ] + runs-on: ubuntu-latest needs: [ tests, integration-tests ] steps: diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index a247eb0..c59b4d7 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -17,6 +17,11 @@ jobs: with: go-version: 1.21 + - name: Ensure No Formatting Changes + run: | + go fmt ./... + git diff --exit-code + - name: Build and Test run: | go test -race ./... diff --git a/README.md b/README.md index 227b595..8495a12 100644 --- a/README.md +++ b/README.md @@ -43,3 +43,17 @@ You can inspect the access logs of aproxy using: ```bash sudo snap logs aproxy.aproxy -n=all ``` + +## Running from Source + +To run this application directly from the source code, you'll need to have Go +1.21 installed on your system. + +Follow these steps to get started: + +```bash +git clone https://github.com/canonical/aproxy.git +cd aproxy +go mod download +go run . --proxy=squid.internal:3128 +``` diff --git a/aproxy.go b/aproxy.go index 5cc599a..80d3011 100644 --- a/aproxy.go +++ b/aproxy.go @@ -1,14 +1,12 @@ -//go:build linux - package main import ( "bufio" + "context" "encoding/binary" "errors" "flag" "fmt" - "golang.org/x/crypto/cryptobyte" "io" "log" "log/slog" @@ -16,73 +14,71 @@ import ( "net/http" "net/url" "os" + "os/signal" "strings" "sync" "sync/atomic" "syscall" - "unsafe" + + "golang.org/x/crypto/cryptobyte" ) -var version = "0.2.0" +var version = "0.2.1" +// PrereadConn is a wrapper around net.Conn that supports pre-reading from the underlying connection. +// Any Read before the EndPreread can be undone and read again by calling the EndPreread function. type PrereadConn struct { - prereadStarted bool - prereadEnded bool - prereadBuf []byte - mu sync.Mutex - conn net.Conn -} - -func (c *PrereadConn) StartPreread() { - c.mu.Lock() - defer c.mu.Unlock() - if c.prereadStarted { - panic("call StartPreread after preread has already started or ended") - } - c.prereadStarted = true + ended bool + buf []byte + mu sync.Mutex + conn net.Conn } -func (c *PrereadConn) RestorePreread() { +// EndPreread ends the pre-reading phase. Any Read before will be undone and data in the stream can be read again. +// EndPreread can be only called once. +func (c *PrereadConn) EndPreread() { c.mu.Lock() defer c.mu.Unlock() - if !c.prereadStarted || c.prereadEnded { - panic("call RestorePreread after preread has ended or hasn't started") + if c.ended { + panic("call EndPreread after preread has ended or hasn't started") } - c.prereadEnded = true + c.ended = true } +// Read reads from the underlying connection. Read during the pre-reading phase can be undone by EndPreread. func (c *PrereadConn) Read(p []byte) (n int, err error) { c.mu.Lock() defer c.mu.Unlock() - if c.prereadEnded { - n = copy(p, c.prereadBuf) - bufLen := len(c.prereadBuf) - c.prereadBuf = c.prereadBuf[n:] + if c.ended { + n = copy(p, c.buf) + bufLen := len(c.buf) + c.buf = c.buf[n:] if n == len(p) || (bufLen > 0 && bufLen == n) { return n, nil } rn, err := c.conn.Read(p[n:]) return rn + n, err - } - if c.prereadStarted { + } else { n, err = c.conn.Read(p) - c.prereadBuf = append(c.prereadBuf, p[:n]...) + c.buf = append(c.buf, p[:n]...) return n, err } - return c.conn.Read(p) } +// Write writes data to the underlying connection. func (c *PrereadConn) Write(p []byte) (n int, err error) { return c.conn.Write(p) } +// NewPrereadConn wraps the network connection and return a *PrereadConn. +// It's recommended to not touch the original connection after wrapped. func NewPrereadConn(conn net.Conn) *PrereadConn { return &PrereadConn{conn: conn} } +// PrereadSNI pre-reads the Server Name Indication (SNI) from a TLS connection. func PrereadSNI(conn *PrereadConn) (_ string, err error) { - conn.StartPreread() - defer conn.RestorePreread() + defer conn.EndPreread() defer func() { if err != nil { err = fmt.Errorf("failed to preread TLS client hello: %w", err) @@ -114,11 +110,11 @@ func PrereadSNI(conn *PrereadConn) (_ string, err error) { func extractSNI(data []byte) (string, error) { s := cryptobyte.String(data) - var version uint16 - var random []byte - var sessionId []byte - var compressionMethods []byte - var cipherSuites []uint16 + var ( + version uint16 + random []byte + sessionId []byte + ) if !s.Skip(9) || !s.ReadUint16(&version) || !s.ReadBytes(&random, 32) || @@ -130,6 +126,8 @@ func extractSNI(data []byte) (string, error) { if !s.ReadUint16LengthPrefixed(&cipherSuitesData) { return "", fmt.Errorf("failed to parse TLS client hello cipher suites") } + + var cipherSuites []uint16 for !cipherSuitesData.Empty() { var suite uint16 if !cipherSuitesData.ReadUint16(&suite) { @@ -138,6 +136,7 @@ func extractSNI(data []byte) (string, error) { cipherSuites = append(cipherSuites, suite) } + var compressionMethods []byte if !s.ReadUint8LengthPrefixed((*cryptobyte.String)(&compressionMethods)) { return "", fmt.Errorf("failed to parse TLS client hello compression methods") } @@ -191,6 +190,7 @@ func extractSNI(data []byte) (string, error) { return finalServerName, nil } +// PrereadHttpHost pre-reads the HTTP Host header from an HTTP connection. func PrereadHttpHost(conn *PrereadConn) (_ string, err error) { defer func() { if err != nil { @@ -198,8 +198,7 @@ func PrereadHttpHost(conn *PrereadConn) (_ string, err error) { } }() - conn.StartPreread() - defer conn.RestorePreread() + defer conn.EndPreread() req, err := http.ReadRequest(bufio.NewReader(conn)) if err != nil { return "", err @@ -211,6 +210,7 @@ func PrereadHttpHost(conn *PrereadConn) (_ string, err error) { return host, nil } +// DialProxy dials the TCP connection to the proxy. func DialProxy(proxy string) (net.Conn, error) { proxyAddr, err := net.ResolveTCPAddr("tcp", proxy) if err != nil { @@ -223,6 +223,7 @@ func DialProxy(proxy string) (net.Conn, error) { return conn, nil } +// DialProxyConnect dials the TCP connection and finishes the HTTP CONNECT handshake with the proxy. func DialProxyConnect(proxy string, dst string) (net.Conn, error) { conn, err := DialProxy(proxy) if err != nil { @@ -255,6 +256,7 @@ func DialProxyConnect(proxy string, dst string) (net.Conn, error) { return conn, nil } +// GetOriginalDst get the original destination address of a TCP connection before dstnat. func GetOriginalDst(conn *net.TCPConn) (*net.TCPAddr, error) { file, err := conn.File() defer func(file *os.File) { @@ -266,27 +268,15 @@ func GetOriginalDst(conn *net.TCPConn) (*net.TCPAddr, error) { if err != nil { return nil, fmt.Errorf("failed to convert connection to file: %w", err) } - var sockaddr [16]byte - size := 16 - _, _, e := syscall.Syscall6( - syscall.SYS_GETSOCKOPT, - file.Fd(), + return GetsockoptIPv4OriginalDst( + int(file.Fd()), syscall.SOL_IP, 80, // SO_ORIGINAL_DST - uintptr(unsafe.Pointer(&sockaddr)), - uintptr(unsafe.Pointer(&size)), - 0, ) - if e != 0 { - return nil, fmt.Errorf("getsockopt SO_ORIGINAL_DST failed: errno %d", e) - } - return &net.TCPAddr{ - IP: sockaddr[4:8], - Port: int(binary.BigEndian.Uint16(sockaddr[2:4])), - }, nil } -func RelayTcp(conn io.ReadWriter, proxyConn io.ReadWriteCloser, logger *slog.Logger) { +// RelayTCP relays data between the incoming TCP connection and the proxy connection. +func RelayTCP(conn io.ReadWriter, proxyConn io.ReadWriteCloser, logger *slog.Logger) { var closed atomic.Bool go func() { _, err := io.Copy(proxyConn, conn) @@ -303,10 +293,10 @@ func RelayTcp(conn io.ReadWriter, proxyConn io.ReadWriteCloser, logger *slog.Log closed.Store(true) } -func RelayHttp(conn io.ReadWriter, proxyConn io.ReadWriteCloser, logger *slog.Logger) { - defer func() { - _ = proxyConn.Close() - }() +// RelayHTTP relays a single HTTP request and response between a local connection and a proxy. +// It modifies the Connection header to "close" in both the request and response. +func RelayHTTP(conn io.ReadWriter, proxyConn io.ReadWriteCloser, logger *slog.Logger) { + defer proxyConn.Close() req, err := http.ReadRequest(bufio.NewReader(conn)) if err != nil { logger.Error("failed to read HTTP request from connection", "error", err) @@ -331,8 +321,9 @@ func RelayHttp(conn io.ReadWriter, proxyConn io.ReadWriteCloser, logger *slog.Lo } } +// HandleConn manages the incoming connections. func HandleConn(conn net.Conn, proxy string) { - defer func() { _ = conn.Close() }() + defer conn.Close() logger := slog.With("src", conn.RemoteAddr()) dst, err := GetOriginalDst(conn.(*net.TCPConn)) if err != nil { @@ -340,50 +331,55 @@ func HandleConn(conn net.Conn, proxy string) { return } logger = logger.With("original_dst", dst) - var host string - var relay func(conn io.ReadWriter, proxyConn io.ReadWriteCloser, logger *slog.Logger) - var dialProxy func(proxy string) (net.Conn, error) consigned := NewPrereadConn(conn) switch dst.Port { case 443: - relay = RelayTcp - sni, sniErr := PrereadSNI(consigned) - if sniErr != nil { - err = sniErr + sni, err := PrereadSNI(consigned) + if err != nil { + logger.Error("failed to preread SNI from connection", "error", err) + return } else { - host = fmt.Sprintf("%s:%d", sni, dst.Port) - dialProxy = func(proxy string) (net.Conn, error) { return DialProxyConnect(proxy, host) } + host := fmt.Sprintf("%s:%d", sni, dst.Port) + logger = logger.With("host", host) + proxyConn, err := DialProxyConnect(proxy, host) + if err != nil { + logger.Error("failed to connect to http proxy", "error", err) + return + } + logger.Info("relay TLS connection to proxy") + RelayTCP(consigned, proxyConn, logger) } case 80: - relay = RelayHttp - host, err = PrereadHttpHost(consigned) + host, err := PrereadHttpHost(consigned) + if err != nil { + logger.Error("failed to preread HTTP host from connection", "error", err) + return + } if !strings.Contains(host, ":") { host = fmt.Sprintf("%s:%d", host, dst.Port) } - dialProxy = DialProxy + logger = logger.With("host", host) + proxyConn, err := DialProxy(proxy) + if err != nil { + logger.Error("failed to connect to http proxy", "error", err) + return + } + logger.Info("relay HTTP connection to proxy") + RelayHTTP(consigned, proxyConn, logger) default: logger.Error(fmt.Sprintf("unknown destination port: %d", dst.Port)) return } - if err != nil { - logger.Error("failed to preread host from connection", "error", err) - return - } - logger = logger.With("host", host) - proxyConn, err := dialProxy(proxy) - if err != nil { - logger.Error("failed to connect to http proxy", "error", err) - return - } - logger.Info("relay connection to http proxy") - relay(consigned, proxyConn, logger) } func main() { proxyFlag := flag.String("proxy", "", "upstream HTTP proxy address in the 'host:port' format") flag.Parse() listenAddr := &net.TCPAddr{IP: net.ParseIP("0.0.0.0"), Port: 8443} - listener, err := net.ListenTCP("tcp", listenAddr) + ctx := context.Background() + signal.NotifyContext(ctx, os.Interrupt) + listenConfig := new(net.ListenConfig) + listener, err := listenConfig.Listen(ctx, "tcp", fmt.Sprintf(":8443")) if err != nil { log.Fatalf("failed to listen on %s", listenAddr.String()) } diff --git a/aproxy_test.go b/aproxy_test.go index a25d1be..b598b5e 100644 --- a/aproxy_test.go +++ b/aproxy_test.go @@ -11,7 +11,6 @@ func TestPrereadConn(t *testing.T) { remote, local := net.Pipe() go remote.Write([]byte("hello, world")) preread := &PrereadConn{conn: local} - preread.StartPreread() buf := make([]byte, 5) _, err := preread.Read(buf) if err != nil { @@ -22,7 +21,7 @@ func TestPrereadConn(t *testing.T) { if err != nil { t.Fatalf("Read failed during preread: %s", err) } - preread.RestorePreread() + preread.EndPreread() buf2 := make([]byte, 12) _, err = io.ReadFull(preread, buf2) if err != nil { diff --git a/snap/snapcraft.yaml b/snap/snapcraft.yaml index 1bb2419..6802db3 100644 --- a/snap/snapcraft.yaml +++ b/snap/snapcraft.yaml @@ -1,5 +1,5 @@ name: aproxy -version: 0.2.0 +version: 0.2.1 summary: Transparent proxy for HTTP and HTTPS/TLS connections. description: | Aproxy is a transparent proxy for HTTP and HTTPS/TLS connections. By diff --git a/syscall_linux.go b/syscall_linux.go new file mode 100644 index 0000000..1b15e4a --- /dev/null +++ b/syscall_linux.go @@ -0,0 +1,32 @@ +//go:build linux + +package main + +import ( + "encoding/binary" + "fmt" + "net" + "syscall" + "unsafe" +) + +func GetsockoptIPv4OriginalDst(fd, level, opt int) (*net.TCPAddr, error) { + var sockaddr [16]byte + size := 16 + _, _, e := syscall.Syscall6( + syscall.SYS_GETSOCKOPT, + uintptr(fd), + uintptr(level), + uintptr(opt), + uintptr(unsafe.Pointer(&sockaddr)), + uintptr(unsafe.Pointer(&size)), + 0, + ) + if e != 0 { + return nil, fmt.Errorf("getsockopt SO_ORIGINAL_DST failed: errno %d", e) + } + return &net.TCPAddr{ + IP: sockaddr[4:8], + Port: int(binary.BigEndian.Uint16(sockaddr[2:4])), + }, nil +}