Skip to content

Commit 27700fc

Browse files
Add connection retry with MaxConnectionAttempts (#16)
The application using this library hangs (never exits) if the SSH tunnel is being used by a lots of goroutines and an error occurs in the forward() method when a connection is being made. Connection attempts seem to intermittently fail, and this somehow leads to the code never exiting when complete. I found that the connection attempt would succeed after 1 or 2 retries, then the application would later exit like normal. This PR adds a optional retry mechanism. It must be enabled with MaxConnectionAttempts. Fixed #15
1 parent 6539d4e commit 27700fc

File tree

1 file changed

+38
-14
lines changed

1 file changed

+38
-14
lines changed

ssh_tunnel.go

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,16 @@ type logger interface {
1212
}
1313

1414
type SSHTunnel struct {
15-
Local *Endpoint
16-
Server *Endpoint
17-
Remote *Endpoint
18-
Config *ssh.ClientConfig
19-
Log logger
20-
Conns []net.Conn
21-
SvrConns []*ssh.Client
22-
isOpen bool
23-
close chan interface{}
15+
Local *Endpoint
16+
Server *Endpoint
17+
Remote *Endpoint
18+
Config *ssh.ClientConfig
19+
Log logger
20+
Conns []net.Conn
21+
SvrConns []*ssh.Client
22+
MaxConnectionAttempts int
23+
isOpen bool
24+
close chan interface{}
2425
}
2526

2627
func (tunnel *SSHTunnel) logf(fmt string, args ...interface{}) {
@@ -45,6 +46,14 @@ func (tunnel *SSHTunnel) Start() error {
4546
tunnel.isOpen = true
4647
tunnel.Local.Port = listener.Addr().(*net.TCPAddr).Port
4748

49+
// Ensure that MaxConnectionAttempts is at least 1. This check is done here
50+
// since the library user can set the value at any point before Start() is called,
51+
// and this check protects against the case where the programmer set MaxConnectionAttempts
52+
// to 0 for some reason.
53+
if tunnel.MaxConnectionAttempts <= 0 {
54+
tunnel.MaxConnectionAttempts = 1
55+
}
56+
4857
for {
4958
if !tunnel.isOpen {
5059
break
@@ -90,14 +99,29 @@ func (tunnel *SSHTunnel) Start() error {
9099
}
91100

92101
func (tunnel *SSHTunnel) forward(localConn net.Conn) {
93-
serverConn, err := ssh.Dial("tcp", tunnel.Server.String(), tunnel.Config)
94-
if err != nil {
95-
tunnel.logf("server dial error: %s", err)
96-
return
102+
var (
103+
serverConn *ssh.Client
104+
err error
105+
attemptsLeft int = tunnel.MaxConnectionAttempts
106+
)
107+
108+
for {
109+
serverConn, err = ssh.Dial("tcp", tunnel.Server.String(), tunnel.Config)
110+
if err != nil {
111+
attemptsLeft--
112+
113+
if attemptsLeft <= 0 {
114+
tunnel.logf("server dial error: %v: exceeded %d attempts", err, tunnel.MaxConnectionAttempts)
115+
return
116+
}
117+
} else {
118+
break
119+
}
97120
}
121+
98122
tunnel.logf("connected to %s (1 of 2)\n", tunnel.Server.String())
99123
tunnel.SvrConns = append(tunnel.SvrConns, serverConn)
100-
124+
101125
remoteConn, err := serverConn.Dial("tcp", tunnel.Remote.String())
102126
if err != nil {
103127
tunnel.logf("remote dial error: %s", err)

0 commit comments

Comments
 (0)