diff --git a/dialer.go b/dialer.go index a8e4c1893..de4419ff8 100644 --- a/dialer.go +++ b/dialer.go @@ -25,7 +25,7 @@ import ( "github.com/quickfixgo/quickfix/config" ) -func loadDialerConfig(settings *SessionSettings) (dialer proxy.Dialer, err error) { +func loadDialerConfig(settings *SessionSettings) (dialer proxy.ContextDialer, err error) { stdDialer := &net.Dialer{} if settings.HasSetting(config.SocketTimeout) { timeout, err := settings.DurationSetting(config.SocketTimeout) @@ -73,9 +73,23 @@ func loadDialerConfig(settings *SessionSettings) (dialer proxy.Dialer, err error } } - dialer, err = proxy.SOCKS5("tcp", fmt.Sprintf("%s:%d", proxyHost, proxyPort), proxyAuth, dialer) + var proxyDialer proxy.Dialer + + proxyDialer, err = proxy.SOCKS5("tcp", fmt.Sprintf("%s:%d", proxyHost, proxyPort), proxyAuth, stdDialer) + if err != nil { + return + } + + if contextDialer, ok := proxyDialer.(proxy.ContextDialer); ok { + dialer = contextDialer + } else { + err = fmt.Errorf("proxy does not support context dialer") + return + } + default: err = fmt.Errorf("unsupported proxy type %s", proxyType) } + return } diff --git a/initiator.go b/initiator.go index 8f7a76200..18451477e 100644 --- a/initiator.go +++ b/initiator.go @@ -17,6 +17,7 @@ package quickfix import ( "bufio" + "context" "crypto/tls" "strings" "sync" @@ -50,7 +51,7 @@ func (i *Initiator) Start() (err error) { return } - var dialer proxy.Dialer + var dialer proxy.ContextDialer if dialer, err = loadDialerConfig(settings); err != nil { return } @@ -142,7 +143,7 @@ func (i *Initiator) waitForReconnectInterval(reconnectInterval time.Duration) bo return true } -func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, dialer proxy.Dialer) { +func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, dialer proxy.ContextDialer) { var wg sync.WaitGroup wg.Add(1) go func() { @@ -162,6 +163,19 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di return } + ctx, cancel := context.WithCancel(context.Background()) + + // We start a goroutine in order to be able to cancel the dialer mid-connection + // on receiving a stop signal to stop the initiator. + go func() { + select { + case <-i.stopChan: + cancel() + case <-ctx.Done(): + return + } + }() + var disconnected chan interface{} var msgIn chan fixIn var msgOut chan []byte @@ -169,7 +183,7 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di address := session.SocketConnectAddress[connectionAttempt%len(session.SocketConnectAddress)] session.log.OnEventf("Connecting to: %v", address) - netConn, err := dialer.Dial("tcp", address) + netConn, err := dialer.DialContext(ctx, "tcp", address) if err != nil { session.log.OnEventf("Failed to connect: %v", err) goto reconnect @@ -208,6 +222,10 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di close(disconnected) }() + // This ensures we properly cleanup the goroutine and context used for + // dial cancelation after successful connection. + cancel() + select { case <-disconnected: case <-i.stopChan: @@ -215,6 +233,8 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di } reconnect: + cancel() + connectionAttempt++ session.log.OnEventf("Reconnecting in %v", session.ReconnectInterval) if !i.waitForReconnectInterval(session.ReconnectInterval) {