5
5
"io"
6
6
"log"
7
7
"net"
8
+ "sync"
8
9
)
9
10
10
11
type SSHTunnel struct {
@@ -13,6 +14,7 @@ type SSHTunnel struct {
13
14
Remote * Endpoint
14
15
Config * ssh.ClientConfig
15
16
Log * log.Logger
17
+ close chan interface {}
16
18
}
17
19
18
20
func (tunnel * SSHTunnel ) logf (fmt string , args ... interface {}) {
@@ -26,47 +28,59 @@ func (tunnel *SSHTunnel) Start() error {
26
28
if err != nil {
27
29
return err
28
30
}
29
- defer listener .Close ()
30
-
31
31
tunnel .Local .Port = listener .Addr ().(* net.TCPAddr ).Port
32
-
33
32
for {
34
33
conn , err := listener .Accept ()
35
34
if err != nil {
36
35
return err
37
36
}
38
-
39
37
tunnel .logf ("accepted connection" )
40
- go tunnel .forward (conn )
38
+ var wg sync.WaitGroup
39
+ go tunnel .forward (conn , & wg )
40
+ wg .Wait ()
41
+ tunnel .logf ("tunnel closed" )
42
+ break
41
43
}
44
+ err = listener .Close ()
45
+ if err != nil {
46
+ return err
47
+ }
48
+ return nil
42
49
}
43
50
44
- func (tunnel * SSHTunnel ) forward (localConn net.Conn ) {
51
+ func (tunnel * SSHTunnel ) forward (localConn net.Conn , wg * sync. WaitGroup ) {
45
52
serverConn , err := ssh .Dial ("tcp" , tunnel .Server .String (), tunnel .Config )
46
53
if err != nil {
47
54
tunnel .logf ("server dial error: %s" , err )
48
55
return
49
56
}
50
-
51
57
tunnel .logf ("connected to %s (1 of 2)\n " , tunnel .Server .String ())
52
-
53
58
remoteConn , err := serverConn .Dial ("tcp" , tunnel .Remote .String ())
54
59
if err != nil {
55
60
tunnel .logf ("remote dial error: %s" , err )
56
61
return
57
62
}
58
-
59
63
tunnel .logf ("connected to %s (2 of 2)\n " , tunnel .Remote .String ())
60
-
61
64
copyConn := func (writer , reader net.Conn ) {
62
65
_ , err := io .Copy (writer , reader )
63
66
if err != nil {
64
67
tunnel .logf ("io.Copy error: %s" , err )
65
68
}
66
69
}
67
-
68
70
go copyConn (localConn , remoteConn )
69
71
go copyConn (remoteConn , localConn )
72
+ <- tunnel .close
73
+ tunnel .logf ("close signal received, closing..." )
74
+ _ = localConn .Close ()
75
+ _ = serverConn .Close ()
76
+ _ = remoteConn .Close ()
77
+ wg .Done ()
78
+ return
79
+ }
80
+
81
+ func (tunnel * SSHTunnel ) Close () {
82
+ tunnel .close <- struct {}{}
83
+ return
70
84
}
71
85
72
86
func NewSSHTunnel (tunnel string , auth ssh.AuthMethod , destination string ) * SSHTunnel {
@@ -90,6 +104,7 @@ func NewSSHTunnel(tunnel string, auth ssh.AuthMethod, destination string) *SSHTu
90
104
Local : localEndpoint ,
91
105
Server : server ,
92
106
Remote : NewEndpoint (destination ),
107
+ close : make (chan interface {}),
93
108
}
94
109
95
110
return sshTunnel
0 commit comments