From 3d56ada3556a986cb7e5da56ede4115e71f098cf Mon Sep 17 00:00:00 2001 From: bnm3k <55702585+bnm3k@users.noreply.github.com> Date: Tue, 20 Jun 2023 18:19:16 +0300 Subject: [PATCH] Fix bug for parsing IPv6 addresses (#17) IPv6 addresses contain colons ':' and since sshtunnel uses strings.Split(addr, ":"), to spit host and port, it ends up resulting in faulty splits. Using net.SplitHostPort fixes this since it handles IPv6 addresses correctly. If the user did not supply a port, then endpoint.Host is left as is. This includes a partially breaking change where NewEdnpoint and NewSSHTunnel will now return an error. If you are not using IPv6 it is safe to ignore this error. --- endpoint.go | 24 +++++++++++++---- endpoint_test.go | 69 ++++++++++++++++++++++++++++++++++++++++++++++++ ssh_tunnel.go | 20 ++++++++++---- 3 files changed, 103 insertions(+), 10 deletions(-) create mode 100644 endpoint_test.go diff --git a/endpoint.go b/endpoint.go index 75b4fc3..7f0bf34 100644 --- a/endpoint.go +++ b/endpoint.go @@ -2,6 +2,7 @@ package sshtunnel import ( "fmt" + "net" "strconv" "strings" ) @@ -12,7 +13,11 @@ type Endpoint struct { User string } -func NewEndpoint(s string) *Endpoint { +// NewEndpoint creates an Endpoint from a string that contains a user, host and +// port. Both User and Port are optional (depending on context). The host can +// be a domain name, IPv4 address or IPv6 address. If it's an IPv6, it must be +// enclosed in square brackets +func NewEndpoint(s string) (*Endpoint, error) { endpoint := &Endpoint{ Host: s, } @@ -22,12 +27,21 @@ func NewEndpoint(s string) *Endpoint { endpoint.Host = parts[1] } - if parts := strings.Split(endpoint.Host, ":"); len(parts) > 1 { - endpoint.Host = parts[0] - endpoint.Port, _ = strconv.Atoi(parts[1]) + host, port, err := net.SplitHostPort(endpoint.Host) + if err != nil { + // if error results from missing port in address, we ignore the error + // since either we'll use a random port assigned by the OS or set a + // suitable default directly, e.g. port 22 for SSH. Also worth noting, + // the host is set to the rest of the string since no port is provided + if !strings.Contains(err.Error(), "missing port in address") { + return nil, err + } + } else { + endpoint.Host = host + endpoint.Port, _ = strconv.Atoi(port) } - return endpoint + return endpoint, nil } func (endpoint *Endpoint) String() string { diff --git a/endpoint_test.go b/endpoint_test.go new file mode 100644 index 0000000..e1ed416 --- /dev/null +++ b/endpoint_test.go @@ -0,0 +1,69 @@ +package sshtunnel_test + +import ( + "reflect" + "testing" + + "github.com/elliotchance/sshtunnel" +) + +func TestCreateEndpoint(t *testing.T) { + // these are test cases for which we expect no error to occur when + // constructing endpoints i.e. they should be correct + testCases := []struct { + input string + expectedEndpoint *sshtunnel.Endpoint + }{ + { + "localhost:9000", + &sshtunnel.Endpoint{ + Host: "localhost", + Port: 9000, + User: "", + }, + }, + { + "ec2-user@jumpbox.us-east-1.mydomain.com", + &sshtunnel.Endpoint{ + Host: "jumpbox.us-east-1.mydomain.com", + Port: 0, + User: "ec2-user", + }, + }, + { + "dqrsdfdssdfx.us-east-1.redshift.amazonaws.com:5439", + &sshtunnel.Endpoint{ + Host: "dqrsdfdssdfx.us-east-1.redshift.amazonaws.com", + Port: 5439, + User: "", + }, + }, + { + "admin@1.2.3.4:22", // IPv4 address + &sshtunnel.Endpoint{ + Host: "1.2.3.4", + Port: 22, + User: "admin", + }, + }, + { + "admin@[2001:db8:1::ab9:C0A8:102]:22", // IPv6 address + &sshtunnel.Endpoint{ + Host: "2001:db8:1::ab9:C0A8:102", + Port: 22, + User: "admin", + }, + }, + } + for i, tc := range testCases { + got, err := sshtunnel.NewEndpoint(tc.input) + if err != nil { + t.Errorf("unexpected error for correct input '%s': %v", + tc.input, err) + } + if !reflect.DeepEqual(got, tc.expectedEndpoint) { + t.Errorf("For test case %d, expected: %+v, got: %+v", + i, *tc.expectedEndpoint, *got) + } + } +} diff --git a/ssh_tunnel.go b/ssh_tunnel.go index 0053f5b..85c6e91 100644 --- a/ssh_tunnel.go +++ b/ssh_tunnel.go @@ -147,15 +147,25 @@ func (tunnel *SSHTunnel) Close() { } // NewSSHTunnel creates a new single-use tunnel. Supplying "0" for localport will use a random port. -func NewSSHTunnel(tunnel string, auth ssh.AuthMethod, destination string, localport string) *SSHTunnel { +func NewSSHTunnel(tunnel string, auth ssh.AuthMethod, destination string, localport string) (*SSHTunnel, error) { - localEndpoint := NewEndpoint("localhost:" + localport) + localEndpoint, err := NewEndpoint("localhost:" + localport) + if err != nil { + return nil, err + } - server := NewEndpoint(tunnel) + server, err := NewEndpoint(tunnel) + if err != nil { + return nil, err + } if server.Port == 0 { server.Port = 22 } + remoteEndpoint, err := NewEndpoint(destination) + if err != nil { + return nil, err + } sshTunnel := &SSHTunnel{ Config: &ssh.ClientConfig{ User: server.User, @@ -167,9 +177,9 @@ func NewSSHTunnel(tunnel string, auth ssh.AuthMethod, destination string, localp }, Local: localEndpoint, Server: server, - Remote: NewEndpoint(destination), + Remote: remoteEndpoint, close: make(chan interface{}), } - return sshTunnel + return sshTunnel, nil }