diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9018762 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +.vscode/ +tunnel diff --git a/README.md b/README.md new file mode 100644 index 0000000..1f4ffdb --- /dev/null +++ b/README.md @@ -0,0 +1,26 @@ +# Reverse TCP tunnel +For TCP-based services that do not have inbound end point, `Reverse TCP Tunnel` allows us to build a reverse TCP tunnel to open in bound access from a public end point. It virutally extends the TCP listening port to a remote machine in which a `Reverse TCP Tunnel` listener is running. + +We use a simple signaling protocol for tunnel establishment and connection multiplexing. + +## Tunnel listener +Tunnel listener runs in a public end point, when it receives a `ListenRequest` from service that needs tunnelled inbound access, it opens a dynamic TCP port at public interface and multiplexes traffic between service client and the service provider. + +Example command to launch tunnel listener +```bash +./tunnel -l 5555 +``` + +## Tunnel Connector +Tunnel connector runs wihin the private network boundary, it has access to services that requires tunnelled inbound access. + +Example command to establish a reverse tunnelling setup + +```bash +./tunnel -c localhost:5555 www.myservice.com:80 +``` + +## Build +``` +go build +``` diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..7e663c7 --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module github.com/kelveny/tunnel + +go 1.16 + +require github.com/stretchr/testify v1.7.0 // indirect diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..b380ae4 --- /dev/null +++ b/go.sum @@ -0,0 +1,10 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/protocol.go b/protocol.go new file mode 100644 index 0000000..72b96c3 --- /dev/null +++ b/protocol.go @@ -0,0 +1,317 @@ +package main + +import ( + "bytes" + "encoding/binary" + "fmt" + "net" +) + +const ( + PDU_LISTEN_REQUEST = 1 + PDU_LISTEN_RESPONSE = 2 + PDU_TUNNEL_CONNECT_REQUEST = 3 + PDU_TUNNEL_CONNECT_RESPONSE = 4 + PDU_TUNNEL_DATA_INDICATION = 5 + PDU_TUNNEL_DISCONNECT_REQUEST = 6 + PDU_TUNNEL_DISCONNECT_RESPONSE = 7 +) + +type Serializable interface { + GetSerialType() int + GetSerialLength() uint32 + SerializeTo(w *bytes.Buffer) + SerializeFrom(r *bytes.Buffer) +} + +func serializeUInt32To(v uint32, w *bytes.Buffer) { + b := make([]byte, 4) + binary.BigEndian.PutUint32(b, v) + w.Write(b) +} + +func serializeUInt32From(r *bytes.Buffer) uint32 { + b := make([]byte, 4) + r.Read(b) + return binary.BigEndian.Uint32(b) +} + +func getStringSerialLength(s string) uint32 { + return uint32(4 + len([]byte(s))) +} + +func serializeStringTo(s string, w *bytes.Buffer) { + l := uint32(len([]byte(s))) + serializeUInt32To(l, w) + w.Write([]byte(s)) +} + +func serializeStringFrom(r *bytes.Buffer) string { + l := serializeUInt32From(r) + + b := make([]byte, int(l)) + r.Read(b) + return string(b) +} + +func getPduSerialLength(pdu Serializable) uint32 { + return 1 + pdu.GetSerialLength() +} + +func serializePduTo(pdu Serializable, w *bytes.Buffer) { + w.WriteByte(byte(pdu.GetSerialType())) + pdu.SerializeTo(w) +} + +func serializePduFrom(r *bytes.Buffer) Serializable { + t, _ := r.ReadByte() + switch int(t) { + case PDU_LISTEN_REQUEST: + pdu := &ListenRequest{} + pdu.SerializeFrom(r) + return pdu + + case PDU_LISTEN_RESPONSE: + pdu := &ListenResponse{} + pdu.SerializeFrom(r) + return pdu + + case PDU_TUNNEL_CONNECT_REQUEST: + pdu := &TunnelConnectRequest{} + pdu.SerializeFrom(r) + return pdu + + case PDU_TUNNEL_CONNECT_RESPONSE: + pdu := &TunnelConnectResponse{} + pdu.SerializeFrom(r) + return pdu + + case PDU_TUNNEL_DATA_INDICATION: + pdu := &TunnelDataIndication{} + pdu.SerializeFrom(r) + return pdu + + case PDU_TUNNEL_DISCONNECT_REQUEST: + pdu := &TunnelDisconnectRequest{} + pdu.SerializeFrom(r) + return pdu + + case PDU_TUNNEL_DISCONNECT_RESPONSE: + pdu := &TunnelDisconnectResponse{} + pdu.SerializeFrom(r) + return pdu + } + + fmt.Printf("Invalid protocol data\n") + return nil +} + +func sendPdu(conn net.Conn, pdu Serializable) error { + l := getPduSerialLength(pdu) + + b := make([]byte, 4) + binary.BigEndian.PutUint32(b, l) + _, err := conn.Write(b) + if err != nil { + return err + } + + buf := bytes.NewBuffer(nil) + serializePduTo(pdu, buf) + + _, err = conn.Write(buf.Bytes()) + + return err +} + +///////////////////////////////////////////////////////////////////////////// + +type ListenRequest struct { + proxyAddress string + proxyPort int +} + +func (pdu *ListenRequest) GetSerialType() int { + return PDU_LISTEN_REQUEST +} + +func (pdu *ListenRequest) GetSerialLength() uint32 { + return 4 + getStringSerialLength(pdu.proxyAddress) +} + +func (pdu *ListenRequest) SerializeTo(w *bytes.Buffer) { + serializeStringTo(pdu.proxyAddress, w) + serializeUInt32To(uint32(pdu.proxyPort), w) +} + +func (pdu *ListenRequest) SerializeFrom(r *bytes.Buffer) { + pdu.proxyAddress = serializeStringFrom(r) + pdu.proxyPort = int(serializeUInt32From(r)) +} + +///////////////////////////////////////////////////////////////////////////// + +type ListenResponse struct { + proxyAddress string + proxyPort int + tunnelAddress string + tunnelPort int +} + +func (pdu *ListenResponse) GetSerialType() int { + return PDU_LISTEN_RESPONSE +} + +func (pdu *ListenResponse) GetSerialLength() uint32 { + return 8 + getStringSerialLength(pdu.proxyAddress) + getStringSerialLength(pdu.tunnelAddress) +} + +func (pdu *ListenResponse) SerializeTo(w *bytes.Buffer) { + serializeStringTo(pdu.proxyAddress, w) + serializeUInt32To(uint32(pdu.proxyPort), w) + serializeStringTo(pdu.tunnelAddress, w) + serializeUInt32To(uint32(pdu.tunnelPort), w) +} + +func (pdu *ListenResponse) SerializeFrom(r *bytes.Buffer) { + pdu.proxyAddress = serializeStringFrom(r) + pdu.proxyPort = int(serializeUInt32From(r)) + pdu.tunnelAddress = serializeStringFrom(r) + pdu.tunnelPort = int(serializeUInt32From(r)) +} + +///////////////////////////////////////////////////////////////////////////// + +// listener -> proxy +type TunnelConnectRequest struct { + dataConnectionHandle uint32 + clientAddress string + + proxyAddress string + proxyPort int +} + +func (pdu *TunnelConnectRequest) GetSerialType() int { + return PDU_TUNNEL_CONNECT_REQUEST +} + +func (pdu *TunnelConnectRequest) GetSerialLength() uint32 { + return 4 + + getStringSerialLength(pdu.clientAddress) + + getStringSerialLength(pdu.proxyAddress) + + 4 +} + +func (pdu *TunnelConnectRequest) SerializeTo(w *bytes.Buffer) { + serializeUInt32To(uint32(pdu.dataConnectionHandle), w) + serializeStringTo(pdu.clientAddress, w) + serializeStringTo(pdu.proxyAddress, w) + serializeUInt32To(uint32(pdu.proxyPort), w) +} + +func (pdu *TunnelConnectRequest) SerializeFrom(r *bytes.Buffer) { + pdu.dataConnectionHandle = Handle(serializeUInt32From(r)) + pdu.clientAddress = serializeStringFrom(r) + pdu.proxyAddress = serializeStringFrom(r) + pdu.proxyPort = int(serializeUInt32From(r)) +} + +///////////////////////////////////////////////////////////////////////////// + +type TunnelConnectResponse struct { + dataConnectionHandle uint32 + proxyConnectionHandle uint32 +} + +func (pdu *TunnelConnectResponse) GetSerialType() int { + return PDU_TUNNEL_CONNECT_RESPONSE +} + +func (pdu *TunnelConnectResponse) GetSerialLength() uint32 { + return 8 +} + +func (pdu *TunnelConnectResponse) SerializeTo(w *bytes.Buffer) { + serializeUInt32To(uint32(pdu.dataConnectionHandle), w) + serializeUInt32To(uint32(pdu.proxyConnectionHandle), w) +} + +func (pdu *TunnelConnectResponse) SerializeFrom(r *bytes.Buffer) { + pdu.dataConnectionHandle = serializeUInt32From(r) + pdu.proxyConnectionHandle = serializeUInt32From(r) +} + +///////////////////////////////////////////////////////////////////////////// + +type TunnelDataIndication struct { + peerConnectionHandle uint32 + data []byte +} + +func (pdu *TunnelDataIndication) GetSerialType() int { + return PDU_TUNNEL_DATA_INDICATION +} + +func (pdu *TunnelDataIndication) GetSerialLength() uint32 { + return uint32(4 + 4 + len(pdu.data)) +} + +func (pdu *TunnelDataIndication) SerializeTo(w *bytes.Buffer) { + serializeUInt32To(uint32(pdu.peerConnectionHandle), w) + serializeUInt32To(uint32(len(pdu.data)), w) + w.Write(pdu.data) +} + +func (pdu *TunnelDataIndication) SerializeFrom(r *bytes.Buffer) { + pdu.peerConnectionHandle = serializeUInt32From(r) + + l := serializeUInt32From(r) + pdu.data = make([]byte, int(l)) + r.Read(pdu.data) +} + +///////////////////////////////////////////////////////////////////////////// + +type TunnelDisconnectRequest struct { + peerConnectionHandle uint32 +} + +func (pdu *TunnelDisconnectRequest) GetSerialType() int { + return PDU_TUNNEL_DISCONNECT_REQUEST +} + +func (pdu *TunnelDisconnectRequest) GetSerialLength() uint32 { + return 4 +} + +func (pdu *TunnelDisconnectRequest) SerializeTo(w *bytes.Buffer) { + serializeUInt32To(uint32(pdu.peerConnectionHandle), w) +} + +func (pdu *TunnelDisconnectRequest) SerializeFrom(r *bytes.Buffer) { + pdu.peerConnectionHandle = serializeUInt32From(r) +} + +///////////////////////////////////////////////////////////////////////////// + +type TunnelDisconnectResponse struct { + peerConnectionHandle uint32 +} + +func (pdu *TunnelDisconnectResponse) GetSerialType() int { + return PDU_TUNNEL_DISCONNECT_RESPONSE +} + +func (pdu *TunnelDisconnectResponse) GetSerialLength() uint32 { + return 4 +} + +func (pdu *TunnelDisconnectResponse) SerializeTo(w *bytes.Buffer) { + serializeUInt32To(uint32(pdu.peerConnectionHandle), w) +} + +func (pdu *TunnelDisconnectResponse) SerializeFrom(r *bytes.Buffer) { + pdu.peerConnectionHandle = serializeUInt32From(r) +} + +///////////////////////////////////////////////////////////////////////////// diff --git a/protocol_test.go b/protocol_test.go new file mode 100644 index 0000000..59eee93 --- /dev/null +++ b/protocol_test.go @@ -0,0 +1,25 @@ +package main + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSerializePdu(t *testing.T) { + assert := require.New(t) + + pdu := &ListenRequest{ + proxyAddress: "www.google.com", + proxyPort: 443, + } + + b := bytes.NewBuffer(nil) + serializePduTo(pdu, b) + + pduClone := serializePduFrom(bytes.NewBuffer(b.Bytes())) + assert.True(pduClone != nil) + assert.True(pduClone.(*ListenRequest).proxyAddress == "www.google.com") + assert.True(pduClone.(*ListenRequest).proxyPort == 443) +} diff --git a/tunnel.go b/tunnel.go new file mode 100644 index 0000000..516d82d --- /dev/null +++ b/tunnel.go @@ -0,0 +1,474 @@ +package main + +import ( + "bytes" + "context" + "encoding/binary" + "flag" + "fmt" + "net" + "strconv" + "strings" + "sync" +) + +type Handle = uint32 + +///////////////////////////////////////////////////////////////////////////// + +type tunnelProvider struct { + lock sync.Mutex + + // map handle -> *TunnelConnection + tunnelConnections map[Handle]*TunnelConnection + + // map handle -> *DataConnection + dataConnections map[Handle]*DataConnection + + nextHandle Handle +} + +func newTunnelProvider() *tunnelProvider { + return &tunnelProvider{ + tunnelConnections: make(map[Handle]*TunnelConnection), + dataConnections: make(map[Handle]*DataConnection), + nextHandle: 1, + } +} + +func (p *tunnelProvider) getNextHandle() Handle { + p.lock.Lock() + defer p.lock.Unlock() + + return p.getNextHandleUnLocked() +} + +func (p *tunnelProvider) getNextHandleUnLocked() Handle { + r := p.nextHandle + p.nextHandle++ + + return r +} + +func (p *tunnelProvider) newTunnelConnection(conn net.Conn) *TunnelConnection { + ctx, cancel := context.WithCancel(context.Background()) + tc := &TunnelConnection{ + provider: p, + conn: conn, + ctx: ctx, + cancel: cancel, + } + + p.lock.Lock() + defer p.lock.Unlock() + + handle := p.getNextHandleUnLocked() + tc.handle = handle + + p.tunnelConnections[handle] = tc + return tc +} + +func (p *tunnelProvider) closeTunnelConnection(tc *TunnelConnection) { + p.lock.Lock() + defer p.lock.Unlock() + + delete(p.tunnelConnections, tc.handle) +} + +func (p *tunnelProvider) getTunnelConnection(handle Handle) *TunnelConnection { + p.lock.Lock() + defer p.lock.Unlock() + + if tc, ok := p.tunnelConnections[handle]; ok { + return tc + } + + return nil +} + +func (p *tunnelProvider) getAndClearTunnelConnection(handle Handle) *TunnelConnection { + p.lock.Lock() + defer p.lock.Unlock() + + if tc, ok := p.tunnelConnections[handle]; ok { + delete(p.tunnelConnections, handle) + return tc + } + + return nil +} + +func (p *tunnelProvider) newDataConnection(tc *TunnelConnection, conn net.Conn) *DataConnection { + ctx, cancel := context.WithCancel(context.Background()) + dc := &DataConnection{ + conn: conn, + + tunnelConnection: tc, + ctx: ctx, + cancel: cancel, + } + + p.lock.Lock() + defer p.lock.Unlock() + + handle := p.getNextHandleUnLocked() + dc.handle = handle + + p.dataConnections[handle] = dc + return dc +} + +func (p *tunnelProvider) closeDataConnection(dc *DataConnection, notifyPeer bool) { + dc = p.getAndClearDataConnection(dc.handle) + if dc != nil { + fmt.Printf("Close data connection, local handle: %d, peer handle: %d\n", + dc.handle, dc.peerHandle) + + dc.conn.Close() + + if notifyPeer { + pdu := &TunnelDisconnectRequest{ + peerConnectionHandle: dc.peerHandle, + } + sendPdu(dc.tunnelConnection.conn, pdu) + } + } +} + +func (p *tunnelProvider) getDataConnection(handle Handle) *DataConnection { + p.lock.Lock() + defer p.lock.Unlock() + + if dc, ok := p.dataConnections[handle]; ok { + return dc + } + + return nil +} + +func (p *tunnelProvider) startListener(port int) { + l, err := net.Listen("tcp4", fmt.Sprintf("0.0.0.0:%d", port)) + if err != nil { + fmt.Printf("TCP listen error: %v\n", err) + return + } + + go func() { + for { + conn, err := l.Accept() + if err != nil { + fmt.Printf("TCP accept error: %v\n", err) + break + } else { + tc := p.newTunnelConnection(conn) + tc.open() + } + } + + l.Close() + }() +} + +func (p *tunnelProvider) startConnector(providerAddress string) (*TunnelConnection, error) { + conn, err := net.Dial("tcp4", providerAddress) + if err != nil { + return nil, err + } + + tc := p.newTunnelConnection(conn) + tc.open() + + return tc, nil +} + +func (p *tunnelProvider) getAndClearDataConnection(handle Handle) *DataConnection { + p.lock.Lock() + defer p.lock.Unlock() + + if dc, ok := p.dataConnections[handle]; ok { + delete(p.dataConnections, handle) + return dc + } + + return nil +} + +func (p *tunnelProvider) onTunnelPacket(tc *TunnelConnection, data []byte) { + r := bytes.NewBuffer(data) + pdu := serializePduFrom(r) + if pdu != nil { + switch int(pdu.GetSerialType()) { + case PDU_LISTEN_REQUEST: + tc.onListenRequest(pdu.(*ListenRequest)) + + case PDU_LISTEN_RESPONSE: + tc.onListenResponse(pdu.(*ListenResponse)) + + case PDU_TUNNEL_CONNECT_REQUEST: + tc.onTunnelConnectRequest(pdu.(*TunnelConnectRequest)) + + case PDU_TUNNEL_CONNECT_RESPONSE: + tc.onTunnelConnectResponse(pdu.(*TunnelConnectResponse)) + + case PDU_TUNNEL_DATA_INDICATION: + tc.onTunnelDataIndication(pdu.(*TunnelDataIndication)) + + case PDU_TUNNEL_DISCONNECT_REQUEST: + tc.onTunnelDisconnectRequest(pdu.(*TunnelDisconnectRequest)) + + case PDU_TUNNEL_DISCONNECT_RESPONSE: + tc.onTunnelDisconnectResponse(pdu.(*TunnelDisconnectResponse)) + } + } +} + +///////////////////////////////////////////////////////////////////////////// + +type DataConnection struct { + conn net.Conn + handle Handle + peerHandle Handle + + tunnelConnection *TunnelConnection + ctx context.Context + cancel context.CancelFunc +} + +func (dc *DataConnection) open(peerHandle Handle) { + dc.peerHandle = peerHandle + + go func() { + b := make([]byte, 4096) + for { + sz, err := dc.conn.Read(b) + + if sz == 0 || err != nil { + dc.close(true) + return + } + + pdu := &TunnelDataIndication{ + peerConnectionHandle: dc.peerHandle, + data: b[0:sz], + } + + // multiplex through tunnel connection + sendPdu(dc.tunnelConnection.conn, pdu) + } + }() +} + +func (dc *DataConnection) close(notifyPeer bool) { + dc.tunnelConnection.provider.closeDataConnection(dc, notifyPeer) +} + +///////////////////////////////////////////////////////////////////////////// + +type TunnelConnection struct { + provider *tunnelProvider + conn net.Conn + handle Handle + + tunnelPort int + + proxyAddress string + proxyPort int + + ctx context.Context + cancel context.CancelFunc +} + +func (tc *TunnelConnection) startListenFor(proxyAddress string, proxyPort int) int { + tc.proxyAddress = proxyAddress + tc.proxyPort = proxyPort + + listener, _ := net.Listen("tcp4", ":0") + tc.tunnelPort = listener.Addr().(*net.TCPAddr).Port + + go func() { + for { + c, err := listener.Accept() + if err != nil { + return + } + + tc.onIncomingDataConnection(c) + } + }() + + return tc.tunnelPort +} + +func (tc *TunnelConnection) startTunnelFor(proxyAddress string, proxyPort int) { + tc.proxyAddress = proxyAddress + tc.proxyPort = proxyPort + + pdu := &ListenRequest{ + proxyAddress: proxyAddress, + proxyPort: proxyPort, + } + + sendPdu(tc.conn, pdu) +} + +func (tc *TunnelConnection) onListenRequest(pdu *ListenRequest) { + tunnelPort := tc.startListenFor(pdu.proxyAddress, pdu.proxyPort) + + responsePdu := &ListenResponse{ + tunnelAddress: "0.0.0.0", + tunnelPort: tunnelPort, + proxyAddress: pdu.proxyAddress, + proxyPort: pdu.proxyPort, + } + + sendPdu(tc.conn, responsePdu) +} + +func (tc *TunnelConnection) onListenResponse(pdu *ListenResponse) { + tc.tunnelPort = pdu.tunnelPort + + fmt.Printf("Tunnel port is open: %d\n", pdu.tunnelPort) +} + +func (tc *TunnelConnection) onTunnelConnectRequest(pdu *TunnelConnectRequest) { + conn, err := net.Dial("tcp4", fmt.Sprintf("%s:%d", tc.proxyAddress, tc.proxyPort)) + + if err != nil { + response := &TunnelDisconnectResponse{ + peerConnectionHandle: pdu.dataConnectionHandle, + } + + sendPdu(tc.conn, response) + return + } + + dc := tc.provider.newDataConnection(tc, conn) + dc.open(pdu.dataConnectionHandle) + + fmt.Printf("Open data connection to target %s:%d. local handle: %d, peer handle: %d\n", + tc.proxyAddress, tc.proxyPort, dc.handle, pdu.dataConnectionHandle) + + response := &TunnelConnectResponse{ + dataConnectionHandle: pdu.dataConnectionHandle, + proxyConnectionHandle: dc.handle, + } + sendPdu(tc.conn, response) +} + +func (tc *TunnelConnection) onTunnelConnectResponse(pdu *TunnelConnectResponse) { + if dc := tc.provider.getDataConnection(pdu.dataConnectionHandle); dc != nil { + dc.open(pdu.proxyConnectionHandle) + + fmt.Printf("Connect data connection to target %s:%d. local handle: %d, peer handle: %d\n", + tc.proxyAddress, tc.proxyPort, dc.handle, pdu.proxyConnectionHandle) + } +} + +func (tc *TunnelConnection) onTunnelDataIndication(pdu *TunnelDataIndication) { + if dc := tc.provider.getDataConnection(pdu.peerConnectionHandle); dc != nil { + _, err := dc.conn.Write(pdu.data) + + if err != nil { + dc.close(true) + } + } +} + +func (tc *TunnelConnection) onTunnelDisconnectRequest(pdu *TunnelDisconnectRequest) { + fmt.Printf("Tunnel disconnect request for local handle: %d\n", pdu.peerConnectionHandle) + + if dc := tc.provider.getDataConnection(pdu.peerConnectionHandle); dc != nil { + dc.close(false) + + response := &TunnelDisconnectResponse{ + peerConnectionHandle: dc.peerHandle, + } + sendPdu(tc.conn, response) + } +} + +func (tc *TunnelConnection) onTunnelDisconnectResponse(pdu *TunnelDisconnectResponse) { + fmt.Printf("Tunnel disconnect response for local handle: %d\n", pdu.peerConnectionHandle) + + if dc := tc.provider.getDataConnection(pdu.peerConnectionHandle); dc != nil { + dc.close(false) + } +} + +func (tc *TunnelConnection) onIncomingDataConnection(conn net.Conn) { + dc := tc.provider.newDataConnection(tc, conn) + + req := &TunnelConnectRequest{ + dataConnectionHandle: dc.handle, + clientAddress: "0.0.0.0", // TODO + + proxyAddress: tc.proxyAddress, + proxyPort: tc.proxyPort, + } + + sendPdu(tc.conn, req) +} + +func (tc *TunnelConnection) open() { + go func() { + for { + b := make([]byte, 4) + len, err := tc.conn.Read(b) + if len < 4 || err != nil { + tc.provider.closeTunnelConnection(tc) + break + } + + dataLength := binary.BigEndian.Uint32(b) + data := make([]byte, dataLength) + len, err = tc.conn.Read(data) + + if len < int(dataLength) || err != nil { + tc.provider.closeTunnelConnection(tc) + break + } + + tc.provider.onTunnelPacket(tc, data) + } + }() +} + +func main() { + port := flag.Int("l", 0, "Tunnel provider signaling port") + providerAddress := flag.String("c", "", "Tunnel provider signaling address") + targetAddress := flag.String("t", "", "Target address to be tunnelled") + + flag.Parse() + + p := newTunnelProvider() + + if *port != 0 { + p.startListener(*port) + + // no graceful shutdown yet + select {} + } else { + if len(*providerAddress) == 0 || len(*targetAddress) == 0 { + fmt.Printf("Usage: tunnel [-l] [[-c] [-t]]\n") + return + } + + tc, err := p.startConnector(*providerAddress) + if err != nil { + fmt.Printf("Error: %s\n", err) + return + } + + addr := strings.Split(*targetAddress, ":") + targetPort := 443 + if len(addr) > 1 { + targetPort, _ = strconv.Atoi(addr[1]) + } + + tc.startTunnelFor(addr[0], targetPort) + + // no graceful shutdown yet + select {} + } +}