Skip to content

Commit a5f5bfd

Browse files
authored
Merge branch 'quickfixgo:main' into dynamic_session_2
2 parents 80a93d4 + 9191a58 commit a5f5bfd

11 files changed

+459
-81
lines changed

config/configuration.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -742,6 +742,47 @@ const (
742742
// - A filepath to a file with read access.
743743
SocketCAFile string = "SocketCAFile"
744744

745+
// SocketPrivateKeyBytes is an optional value containing raw bytes of a PEM
746+
// encoded private key to use for secure TLS communications.
747+
// Must be used with SocketCertificateBytes.
748+
// Must contain PEM encoded data.
749+
//
750+
// Required: No
751+
//
752+
// Default: N/A
753+
//
754+
// Valid Values:
755+
// - Raw bytes containing a valid PEM encoded private key.
756+
SocketPrivateKeyBytes string = "SocketPrivateKeyBytes"
757+
758+
// SocketCertificateBytes is an optional value containing raw bytes of a PEM
759+
// encoded certificate to use for secure TLS communications.
760+
// Must be used with SocketPrivateKeyBytes.
761+
// Must contain PEM encoded data.
762+
//
763+
// Required: No
764+
//
765+
// Default: N/A
766+
//
767+
// Valid Values:
768+
// - Raw bytes containing a valid PEM encoded certificate.
769+
SocketCertificateBytes string = "SocketCertificateBytes"
770+
771+
// SocketCABytes is an optional value containing raw bytes of a PEM encoded
772+
// root CA to use for secure TLS communications. For acceptors, client
773+
// certificates will be verified against this CA. For initiators, clients
774+
// will use the CA to verify the server certificate. If not configured,
775+
// initiators will verify the server certificates using the host's root CA
776+
// set.
777+
//
778+
// Required: No
779+
//
780+
// Default: N/A
781+
//
782+
// Valid Values:
783+
// - Raw bytes containing a valid PEM encoded CA.
784+
SocketCABytes string = "SocketCABytes"
785+
745786
// SocketInsecureSkipVerify controls whether a client verifies the server's certificate chain and host name.
746787
// If SocketInsecureSkipVerify is set to Y, crypto/tls accepts any certificate presented by the server and any host name in that certificate.
747788
// In this mode, TLS is susceptible to machine-in-the-middle attacks unless custom verification is used.

dialer.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import (
2525
"github.com/quickfixgo/quickfix/config"
2626
)
2727

28-
func loadDialerConfig(settings *SessionSettings) (dialer proxy.Dialer, err error) {
28+
func loadDialerConfig(settings *SessionSettings) (dialer proxy.ContextDialer, err error) {
2929
stdDialer := &net.Dialer{}
3030
if settings.HasSetting(config.SocketTimeout) {
3131
timeout, err := settings.DurationSetting(config.SocketTimeout)
@@ -73,9 +73,23 @@ func loadDialerConfig(settings *SessionSettings) (dialer proxy.Dialer, err error
7373
}
7474
}
7575

76-
dialer, err = proxy.SOCKS5("tcp", fmt.Sprintf("%s:%d", proxyHost, proxyPort), proxyAuth, dialer)
76+
var proxyDialer proxy.Dialer
77+
78+
proxyDialer, err = proxy.SOCKS5("tcp", fmt.Sprintf("%s:%d", proxyHost, proxyPort), proxyAuth, stdDialer)
79+
if err != nil {
80+
return
81+
}
82+
83+
if contextDialer, ok := proxyDialer.(proxy.ContextDialer); ok {
84+
dialer = contextDialer
85+
} else {
86+
err = fmt.Errorf("proxy does not support context dialer")
87+
return
88+
}
89+
7790
default:
7891
err = fmt.Errorf("unsupported proxy type %s", proxyType)
7992
}
93+
8094
return
8195
}

field_map.go

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,20 @@ func (m FieldMap) GetField(tag Tag, parser FieldValueReader) MessageRejectError
115115
return nil
116116
}
117117

118+
// GetField parses of a field with Tag tag. Returned reject may indicate the field is not present, or the field value is invalid.
119+
func (m FieldMap) getFieldNoLock(tag Tag, parser FieldValueReader) MessageRejectError {
120+
f, ok := m.tagLookup[tag]
121+
if !ok {
122+
return ConditionallyRequiredFieldMissing(tag)
123+
}
124+
125+
if err := parser.Read(f[0].value); err != nil {
126+
return IncorrectDataFormatForValue(tag)
127+
}
128+
129+
return nil
130+
}
131+
118132
// GetBytes is a zero-copy GetField wrapper for []bytes fields.
119133
func (m FieldMap) GetBytes(tag Tag) ([]byte, MessageRejectError) {
120134
m.rwLock.RLock()
@@ -128,6 +142,16 @@ func (m FieldMap) GetBytes(tag Tag) ([]byte, MessageRejectError) {
128142
return f[0].value, nil
129143
}
130144

145+
// getBytesNoLock is a lock free zero-copy GetField wrapper for []bytes fields.
146+
func (m FieldMap) getBytesNoLock(tag Tag) ([]byte, MessageRejectError) {
147+
f, ok := m.tagLookup[tag]
148+
if !ok {
149+
return nil, ConditionallyRequiredFieldMissing(tag)
150+
}
151+
152+
return f[0].value, nil
153+
}
154+
131155
// GetBool is a GetField wrapper for bool fields.
132156
func (m FieldMap) GetBool(tag Tag) (bool, MessageRejectError) {
133157
var val FIXBoolean
@@ -152,6 +176,21 @@ func (m FieldMap) GetInt(tag Tag) (int, MessageRejectError) {
152176
return int(val), err
153177
}
154178

179+
// GetInt is a lock free GetField wrapper for int fields.
180+
func (m FieldMap) getIntNoLock(tag Tag) (int, MessageRejectError) {
181+
bytes, err := m.getBytesNoLock(tag)
182+
if err != nil {
183+
return 0, err
184+
}
185+
186+
var val FIXInt
187+
if val.Read(bytes) != nil {
188+
err = IncorrectDataFormatForValue(tag)
189+
}
190+
191+
return int(val), err
192+
}
193+
155194
// GetTime is a GetField wrapper for utc timestamp fields.
156195
func (m FieldMap) GetTime(tag Tag) (t time.Time, err MessageRejectError) {
157196
m.rwLock.RLock()
@@ -179,6 +218,15 @@ func (m FieldMap) GetString(tag Tag) (string, MessageRejectError) {
179218
return string(val), nil
180219
}
181220

221+
// GetString is a GetField wrapper for string fields.
222+
func (m FieldMap) getStringNoLock(tag Tag) (string, MessageRejectError) {
223+
var val FIXString
224+
if err := m.getFieldNoLock(tag, &val); err != nil {
225+
return "", err
226+
}
227+
return string(val), nil
228+
}
229+
182230
// GetGroup is a Get function specific to Group Fields.
183231
func (m FieldMap) GetGroup(parser FieldGroupReader) MessageRejectError {
184232
m.rwLock.RLock()
@@ -246,6 +294,13 @@ func (m *FieldMap) Clear() {
246294
}
247295
}
248296

297+
func (m *FieldMap) clearNoLock() {
298+
m.tags = m.tags[0:0]
299+
for k := range m.tagLookup {
300+
delete(m.tagLookup, k)
301+
}
302+
}
303+
249304
// CopyInto overwrites the given FieldMap with this one.
250305
func (m *FieldMap) CopyInto(to *FieldMap) {
251306
m.rwLock.RLock()
@@ -263,9 +318,6 @@ func (m *FieldMap) CopyInto(to *FieldMap) {
263318
}
264319

265320
func (m *FieldMap) add(f field) {
266-
m.rwLock.Lock()
267-
defer m.rwLock.Unlock()
268-
269321
t := fieldTag(f)
270322
if _, ok := m.tagLookup[t]; !ok {
271323
m.tags = append(m.tags, t)

initiator.go

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ package quickfix
1717

1818
import (
1919
"bufio"
20+
"context"
2021
"crypto/tls"
2122
"strings"
2223
"sync"
@@ -50,7 +51,7 @@ func (i *Initiator) Start() (err error) {
5051
return
5152
}
5253

53-
var dialer proxy.Dialer
54+
var dialer proxy.ContextDialer
5455
if dialer, err = loadDialerConfig(settings); err != nil {
5556
return
5657
}
@@ -142,7 +143,7 @@ func (i *Initiator) waitForReconnectInterval(reconnectInterval time.Duration) bo
142143
return true
143144
}
144145

145-
func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, dialer proxy.Dialer) {
146+
func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, dialer proxy.ContextDialer) {
146147
var wg sync.WaitGroup
147148
wg.Add(1)
148149
go func() {
@@ -162,14 +163,27 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di
162163
return
163164
}
164165

166+
ctx, cancel := context.WithCancel(context.Background())
167+
168+
// We start a goroutine in order to be able to cancel the dialer mid-connection
169+
// on receiving a stop signal to stop the initiator.
170+
go func() {
171+
select {
172+
case <-i.stopChan:
173+
cancel()
174+
case <-ctx.Done():
175+
return
176+
}
177+
}()
178+
165179
var disconnected chan interface{}
166180
var msgIn chan fixIn
167181
var msgOut chan []byte
168182

169183
address := session.SocketConnectAddress[connectionAttempt%len(session.SocketConnectAddress)]
170184
session.log.OnEventf("Connecting to: %v", address)
171185

172-
netConn, err := dialer.Dial("tcp", address)
186+
netConn, err := dialer.DialContext(ctx, "tcp", address)
173187
if err != nil {
174188
session.log.OnEventf("Failed to connect: %v", err)
175189
goto reconnect
@@ -208,13 +222,19 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di
208222
close(disconnected)
209223
}()
210224

225+
// This ensures we properly cleanup the goroutine and context used for
226+
// dial cancelation after successful connection.
227+
cancel()
228+
211229
select {
212230
case <-disconnected:
213231
case <-i.stopChan:
214232
return
215233
}
216234

217235
reconnect:
236+
cancel()
237+
218238
connectionAttempt++
219239
session.log.OnEventf("Reconnecting in %v", session.ReconnectInterval)
220240
if !i.waitForReconnectInterval(session.ReconnectInterval) {

message.go

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -181,18 +181,20 @@ func ParseMessageWithDataDictionary(
181181

182182
// doParsing executes the message parsing process.
183183
func doParsing(mp *msgParser) (err error) {
184+
mp.msg.Header.rwLock.Lock()
185+
defer mp.msg.Header.rwLock.Unlock()
186+
mp.msg.Body.rwLock.Lock()
187+
defer mp.msg.Body.rwLock.Unlock()
188+
mp.msg.Trailer.rwLock.Lock()
189+
defer mp.msg.Trailer.rwLock.Unlock()
190+
184191
// Initialize for parsing.
185-
mp.msg.Header.Clear()
186-
mp.msg.Body.Clear()
187-
mp.msg.Trailer.Clear()
192+
mp.msg.Header.clearNoLock()
193+
mp.msg.Body.clearNoLock()
194+
mp.msg.Trailer.clearNoLock()
188195

189196
// Allocate expected message fields in one chunk.
190-
fieldCount := 0
191-
for _, b := range mp.rawBytes {
192-
if b == '\001' {
193-
fieldCount++
194-
}
195-
}
197+
fieldCount := bytes.Count(mp.rawBytes, []byte{'\001'})
196198
if fieldCount == 0 {
197199
return parseError{OrigError: fmt.Sprintf("No Fields detected in %s", string(mp.rawBytes))}
198200
}
@@ -267,7 +269,7 @@ func doParsing(mp *msgParser) (err error) {
267269
}
268270

269271
if mp.parsedFieldBytes.tag == tagXMLDataLen {
270-
xmlDataLen, _ = mp.msg.Header.GetInt(tagXMLDataLen)
272+
xmlDataLen, _ = mp.msg.Header.getIntNoLock(tagXMLDataLen)
271273
}
272274
mp.fieldIndex++
273275
}
@@ -292,7 +294,7 @@ func doParsing(mp *msgParser) (err error) {
292294
}
293295
}
294296

295-
bodyLength, err := mp.msg.Header.GetInt(tagBodyLength)
297+
bodyLength, err := mp.msg.Header.getIntNoLock(tagBodyLength)
296298
if err != nil {
297299
err = parseError{OrigError: err.Error()}
298300
} else if length != bodyLength && !xmlDataMsg {
@@ -373,7 +375,7 @@ func parseGroup(mp *msgParser, tags []Tag) {
373375
// tags slice will contain multiple tags if the tag in question is found while processing a group already.
374376
func isNumInGroupField(msg *Message, tags []Tag, appDataDictionary *datadictionary.DataDictionary) bool {
375377
if appDataDictionary != nil {
376-
msgt, err := msg.MsgType()
378+
msgt, err := msg.msgTypeNoLock()
377379
if err != nil {
378380
return false
379381
}
@@ -406,7 +408,7 @@ func isNumInGroupField(msg *Message, tags []Tag, appDataDictionary *datadictiona
406408
// tags slice will contain multiple tags if the tag in question is found while processing a group already.
407409
func getGroupFields(msg *Message, tags []Tag, appDataDictionary *datadictionary.DataDictionary) (fields []*datadictionary.FieldDef) {
408410
if appDataDictionary != nil {
409-
msgt, err := msg.MsgType()
411+
msgt, err := msg.msgTypeNoLock()
410412
if err != nil {
411413
return
412414
}
@@ -476,6 +478,10 @@ func (m *Message) MsgType() (string, MessageRejectError) {
476478
return m.Header.GetString(tagMsgType)
477479
}
478480

481+
func (m *Message) msgTypeNoLock() (string, MessageRejectError) {
482+
return m.Header.getStringNoLock(tagMsgType)
483+
}
484+
479485
// IsMsgTypeOf returns true if the Header contains MsgType (tag 35) field and its value is the specified one.
480486
func (m *Message) IsMsgTypeOf(msgType string) bool {
481487
if v, err := m.MsgType(); err == nil {

session_factory.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ func (f sessionFactory) newSession(
284284
for _, dayStr := range dayStrs {
285285
day, ok := dayLookup[dayStr]
286286
if !ok {
287-
err = IncorrectFormatForSetting{Setting: config.Weekdays, Value: weekdaysStr}
287+
err = IncorrectFormatForSetting{Setting: config.Weekdays, Value: []byte(weekdaysStr)}
288288
return
289289
}
290290
weekdays = append(weekdays, day)
@@ -315,7 +315,7 @@ func (f sessionFactory) newSession(
315315
parseDay := func(setting, dayStr string) (day time.Weekday, err error) {
316316
day, ok := dayLookup[dayStr]
317317
if !ok {
318-
return day, IncorrectFormatForSetting{Setting: setting, Value: dayStr}
318+
return day, IncorrectFormatForSetting{Setting: setting, Value: []byte(dayStr)}
319319
}
320320
return
321321
}
@@ -355,7 +355,7 @@ func (f sessionFactory) newSession(
355355
s.timestampPrecision = Nanos
356356

357357
default:
358-
err = IncorrectFormatForSetting{Setting: config.TimeStampPrecision, Value: precisionStr}
358+
err = IncorrectFormatForSetting{Setting: config.TimeStampPrecision, Value: []byte(precisionStr)}
359359
return
360360
}
361361
}

0 commit comments

Comments
 (0)