Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply suggestions from Ben Hoyt's review #5

Merged
merged 2 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/publish.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:

publish:
name: Publish Aproxy
runs-on: [ self-hosted, linux, x64, large ]
runs-on: ubuntu-latest
needs: [ tests, integration-tests ]

steps:
Expand Down
5 changes: 5 additions & 0 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ jobs:
with:
go-version: 1.21

- name: Ensure No Formatting Changes
run: |
go fmt ./...
git diff --exit-code

- name: Build and Test
run: |
go test -race ./...
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,17 @@ You can inspect the access logs of aproxy using:
```bash
sudo snap logs aproxy.aproxy -n=all
```

## Running from Source

To run this application directly from the source code, you'll need to have Go
1.21 installed on your system.

Follow these steps to get started:

```bash
git clone https://github.com/canonical/aproxy.git
cd aproxy
go mod download
go run . --proxy=squid.internal:3128
```
168 changes: 82 additions & 86 deletions aproxy.go
Original file line number Diff line number Diff line change
@@ -1,88 +1,84 @@
//go:build linux

package main

import (
"bufio"
"context"
"encoding/binary"
"errors"
"flag"
"fmt"
"golang.org/x/crypto/cryptobyte"
"io"
"log"
"log/slog"
"net"
"net/http"
"net/url"
"os"
"os/signal"
"strings"
"sync"
"sync/atomic"
"syscall"
"unsafe"

"golang.org/x/crypto/cryptobyte"
)

var version = "0.2.0"
var version = "0.2.1"

// PrereadConn is a wrapper around net.Conn that supports pre-reading from the underlying connection.
// Any Read before the EndPreread can be undone and read again by calling the EndPreread function.
type PrereadConn struct {
prereadStarted bool
prereadEnded bool
prereadBuf []byte
mu sync.Mutex
conn net.Conn
}

func (c *PrereadConn) StartPreread() {
c.mu.Lock()
defer c.mu.Unlock()
if c.prereadStarted {
panic("call StartPreread after preread has already started or ended")
}
c.prereadStarted = true
ended bool
buf []byte
mu sync.Mutex
conn net.Conn
}

func (c *PrereadConn) RestorePreread() {
// EndPreread ends the pre-reading phase. Any Read before will be undone and data in the stream can be read again.
// EndPreread can be only called once.
func (c *PrereadConn) EndPreread() {
c.mu.Lock()
defer c.mu.Unlock()
if !c.prereadStarted || c.prereadEnded {
panic("call RestorePreread after preread has ended or hasn't started")
if c.ended {
panic("call EndPreread after preread has ended or hasn't started")
}
c.prereadEnded = true
c.ended = true
}

// Read reads from the underlying connection. Read during the pre-reading phase can be undone by EndPreread.
func (c *PrereadConn) Read(p []byte) (n int, err error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.prereadEnded {
n = copy(p, c.prereadBuf)
bufLen := len(c.prereadBuf)
c.prereadBuf = c.prereadBuf[n:]
if c.ended {
n = copy(p, c.buf)
bufLen := len(c.buf)
c.buf = c.buf[n:]
if n == len(p) || (bufLen > 0 && bufLen == n) {
return n, nil
}
rn, err := c.conn.Read(p[n:])
return rn + n, err
}
if c.prereadStarted {
} else {
n, err = c.conn.Read(p)
c.prereadBuf = append(c.prereadBuf, p[:n]...)
c.buf = append(c.buf, p[:n]...)
return n, err
}
return c.conn.Read(p)
}

// Write writes data to the underlying connection.
func (c *PrereadConn) Write(p []byte) (n int, err error) {
return c.conn.Write(p)
}

// NewPrereadConn wraps the network connection and return a *PrereadConn.
// It's recommended to not touch the original connection after wrapped.
func NewPrereadConn(conn net.Conn) *PrereadConn {
return &PrereadConn{conn: conn}
}

// PrereadSNI pre-reads the Server Name Indication (SNI) from a TLS connection.
func PrereadSNI(conn *PrereadConn) (_ string, err error) {
conn.StartPreread()
defer conn.RestorePreread()
defer conn.EndPreread()
defer func() {
if err != nil {
err = fmt.Errorf("failed to preread TLS client hello: %w", err)
Expand Down Expand Up @@ -114,11 +110,11 @@ func PrereadSNI(conn *PrereadConn) (_ string, err error) {

func extractSNI(data []byte) (string, error) {
s := cryptobyte.String(data)
var version uint16
var random []byte
var sessionId []byte
var compressionMethods []byte
var cipherSuites []uint16
var (
version uint16
random []byte
sessionId []byte
)

if !s.Skip(9) ||
!s.ReadUint16(&version) || !s.ReadBytes(&random, 32) ||
Expand All @@ -130,6 +126,8 @@ func extractSNI(data []byte) (string, error) {
if !s.ReadUint16LengthPrefixed(&cipherSuitesData) {
return "", fmt.Errorf("failed to parse TLS client hello cipher suites")
}

var cipherSuites []uint16
for !cipherSuitesData.Empty() {
var suite uint16
if !cipherSuitesData.ReadUint16(&suite) {
Expand All @@ -138,6 +136,7 @@ func extractSNI(data []byte) (string, error) {
cipherSuites = append(cipherSuites, suite)
}

var compressionMethods []byte
if !s.ReadUint8LengthPrefixed((*cryptobyte.String)(&compressionMethods)) {
return "", fmt.Errorf("failed to parse TLS client hello compression methods")
}
Expand Down Expand Up @@ -191,15 +190,15 @@ func extractSNI(data []byte) (string, error) {
return finalServerName, nil
}

// PrereadHttpHost pre-reads the HTTP Host header from an HTTP connection.
func PrereadHttpHost(conn *PrereadConn) (_ string, err error) {
defer func() {
if err != nil {
err = fmt.Errorf("failed to preread HTTP request: %w", err)
}
}()

conn.StartPreread()
defer conn.RestorePreread()
defer conn.EndPreread()
req, err := http.ReadRequest(bufio.NewReader(conn))
if err != nil {
return "", err
Expand All @@ -211,6 +210,7 @@ func PrereadHttpHost(conn *PrereadConn) (_ string, err error) {
return host, nil
}

// DialProxy dials the TCP connection to the proxy.
func DialProxy(proxy string) (net.Conn, error) {
proxyAddr, err := net.ResolveTCPAddr("tcp", proxy)
if err != nil {
Expand All @@ -223,6 +223,7 @@ func DialProxy(proxy string) (net.Conn, error) {
return conn, nil
}

// DialProxyConnect dials the TCP connection and finishes the HTTP CONNECT handshake with the proxy.
func DialProxyConnect(proxy string, dst string) (net.Conn, error) {
conn, err := DialProxy(proxy)
if err != nil {
Expand Down Expand Up @@ -255,6 +256,7 @@ func DialProxyConnect(proxy string, dst string) (net.Conn, error) {
return conn, nil
}

// GetOriginalDst get the original destination address of a TCP connection before dstnat.
func GetOriginalDst(conn *net.TCPConn) (*net.TCPAddr, error) {
file, err := conn.File()
defer func(file *os.File) {
Expand All @@ -266,27 +268,15 @@ func GetOriginalDst(conn *net.TCPConn) (*net.TCPAddr, error) {
if err != nil {
return nil, fmt.Errorf("failed to convert connection to file: %w", err)
}
var sockaddr [16]byte
size := 16
_, _, e := syscall.Syscall6(
syscall.SYS_GETSOCKOPT,
file.Fd(),
return GetsockoptIPv4OriginalDst(
int(file.Fd()),
syscall.SOL_IP,
80, // SO_ORIGINAL_DST
uintptr(unsafe.Pointer(&sockaddr)),
uintptr(unsafe.Pointer(&size)),
0,
)
if e != 0 {
return nil, fmt.Errorf("getsockopt SO_ORIGINAL_DST failed: errno %d", e)
}
return &net.TCPAddr{
IP: sockaddr[4:8],
Port: int(binary.BigEndian.Uint16(sockaddr[2:4])),
}, nil
}

func RelayTcp(conn io.ReadWriter, proxyConn io.ReadWriteCloser, logger *slog.Logger) {
// RelayTCP relays data between the incoming TCP connection and the proxy connection.
func RelayTCP(conn io.ReadWriter, proxyConn io.ReadWriteCloser, logger *slog.Logger) {
var closed atomic.Bool
go func() {
_, err := io.Copy(proxyConn, conn)
Expand All @@ -303,10 +293,10 @@ func RelayTcp(conn io.ReadWriter, proxyConn io.ReadWriteCloser, logger *slog.Log
closed.Store(true)
}

func RelayHttp(conn io.ReadWriter, proxyConn io.ReadWriteCloser, logger *slog.Logger) {
defer func() {
_ = proxyConn.Close()
}()
// RelayHTTP relays a single HTTP request and response between a local connection and a proxy.
// It modifies the Connection header to "close" in both the request and response.
func RelayHTTP(conn io.ReadWriter, proxyConn io.ReadWriteCloser, logger *slog.Logger) {
defer proxyConn.Close()
req, err := http.ReadRequest(bufio.NewReader(conn))
if err != nil {
logger.Error("failed to read HTTP request from connection", "error", err)
Expand All @@ -331,59 +321,65 @@ func RelayHttp(conn io.ReadWriter, proxyConn io.ReadWriteCloser, logger *slog.Lo
}
}

// HandleConn manages the incoming connections.
func HandleConn(conn net.Conn, proxy string) {
defer func() { _ = conn.Close() }()
defer conn.Close()
logger := slog.With("src", conn.RemoteAddr())
dst, err := GetOriginalDst(conn.(*net.TCPConn))
if err != nil {
slog.Error("failed to get connection original destination", "error", err)
return
}
logger = logger.With("original_dst", dst)
var host string
var relay func(conn io.ReadWriter, proxyConn io.ReadWriteCloser, logger *slog.Logger)
var dialProxy func(proxy string) (net.Conn, error)
consigned := NewPrereadConn(conn)
switch dst.Port {
case 443:
relay = RelayTcp
sni, sniErr := PrereadSNI(consigned)
if sniErr != nil {
err = sniErr
sni, err := PrereadSNI(consigned)
if err != nil {
logger.Error("failed to preread SNI from connection", "error", err)
return
} else {
host = fmt.Sprintf("%s:%d", sni, dst.Port)
dialProxy = func(proxy string) (net.Conn, error) { return DialProxyConnect(proxy, host) }
host := fmt.Sprintf("%s:%d", sni, dst.Port)
logger = logger.With("host", host)
proxyConn, err := DialProxyConnect(proxy, host)
if err != nil {
logger.Error("failed to connect to http proxy", "error", err)
return
}
logger.Info("relay TLS connection to proxy")
RelayTCP(consigned, proxyConn, logger)
}
case 80:
relay = RelayHttp
host, err = PrereadHttpHost(consigned)
host, err := PrereadHttpHost(consigned)
if err != nil {
logger.Error("failed to preread HTTP host from connection", "error", err)
return
}
if !strings.Contains(host, ":") {
host = fmt.Sprintf("%s:%d", host, dst.Port)
}
dialProxy = DialProxy
logger = logger.With("host", host)
proxyConn, err := DialProxy(proxy)
if err != nil {
logger.Error("failed to connect to http proxy", "error", err)
return
}
logger.Info("relay HTTP connection to proxy")
RelayHTTP(consigned, proxyConn, logger)
default:
logger.Error(fmt.Sprintf("unknown destination port: %d", dst.Port))
return
}
if err != nil {
logger.Error("failed to preread host from connection", "error", err)
return
}
logger = logger.With("host", host)
proxyConn, err := dialProxy(proxy)
if err != nil {
logger.Error("failed to connect to http proxy", "error", err)
return
}
logger.Info("relay connection to http proxy")
relay(consigned, proxyConn, logger)
}

func main() {
proxyFlag := flag.String("proxy", "", "upstream HTTP proxy address in the 'host:port' format")
flag.Parse()
listenAddr := &net.TCPAddr{IP: net.ParseIP("0.0.0.0"), Port: 8443}
listener, err := net.ListenTCP("tcp", listenAddr)
ctx := context.Background()
signal.NotifyContext(ctx, os.Interrupt)
listenConfig := new(net.ListenConfig)
listener, err := listenConfig.Listen(ctx, "tcp", fmt.Sprintf(":8443"))
if err != nil {
log.Fatalf("failed to listen on %s", listenAddr.String())
}
Expand Down
3 changes: 1 addition & 2 deletions aproxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ func TestPrereadConn(t *testing.T) {
remote, local := net.Pipe()
go remote.Write([]byte("hello, world"))
preread := &PrereadConn{conn: local}
preread.StartPreread()
buf := make([]byte, 5)
_, err := preread.Read(buf)
if err != nil {
Expand All @@ -22,7 +21,7 @@ func TestPrereadConn(t *testing.T) {
if err != nil {
t.Fatalf("Read failed during preread: %s", err)
}
preread.RestorePreread()
preread.EndPreread()
buf2 := make([]byte, 12)
_, err = io.ReadFull(preread, buf2)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion snap/snapcraft.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name: aproxy
version: 0.2.0
version: 0.2.1
summary: Transparent proxy for HTTP and HTTPS/TLS connections.
description: |
Aproxy is a transparent proxy for HTTP and HTTPS/TLS connections. By
Expand Down
Loading
Loading