Skip to content

Commit 28ebb24

Browse files
vilius-glukebakken
authored andcommitted
Add support for additional AMQP URI query parameters
https://www.rabbitmq.com/docs/uri-query-parameters specifies several parameters that are used in this library, but not yet supported in URIs. This commit adds support for the following parameters: auth_mechanism heartbeat connection_timeout channel_max Fix default value check when setting SASL authentication from URI Add documentation for added query parameters Add support for additional AMQP URI query parameters https://www.rabbitmq.com/docs/uri-query-parameters specifies several parameters that are used in this library, but not yet supported in URIs. This commit adds support for the following parameters: auth_mechanism heartbeat connection_timeout channel_max Fix default value check when setting SASL authentication from URI Fix ChannelMax type mismatch Use URI heartbeat Bump versions on Windows
1 parent a2fcd5b commit 28ebb24

File tree

5 files changed

+122
-19
lines changed

5 files changed

+122
-19
lines changed

.ci/versions.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
{
2-
"erlang": "26.1.1",
3-
"rabbitmq": "3.12.6"
2+
"erlang": "26.2.2",
3+
"rabbitmq": "3.13.0"
44
}

connection.go

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,7 @@ func DefaultDial(connectionTimeout time.Duration) func(network, addr string) (ne
157157
// scheme. It is equivalent to calling DialTLS(amqp, nil).
158158
func Dial(url string) (*Connection, error) {
159159
return DialConfig(url, Config{
160-
Heartbeat: defaultHeartbeat,
161-
Locale: defaultLocale,
160+
Locale: defaultLocale,
162161
})
163162
}
164163

@@ -169,7 +168,6 @@ func Dial(url string) (*Connection, error) {
169168
// DialTLS uses the provided tls.Config when encountering an amqps:// scheme.
170169
func DialTLS(url string, amqps *tls.Config) (*Connection, error) {
171170
return DialConfig(url, Config{
172-
Heartbeat: defaultHeartbeat,
173171
TLSClientConfig: amqps,
174172
Locale: defaultLocale,
175173
})
@@ -186,7 +184,6 @@ func DialTLS(url string, amqps *tls.Config) (*Connection, error) {
186184
// amqps:// scheme.
187185
func DialTLS_ExternalAuth(url string, amqps *tls.Config) (*Connection, error) {
188186
return DialConfig(url, Config{
189-
Heartbeat: defaultHeartbeat,
190187
TLSClientConfig: amqps,
191188
SASL: []Authentication{&ExternalAuth{}},
192189
})
@@ -195,7 +192,9 @@ func DialTLS_ExternalAuth(url string, amqps *tls.Config) (*Connection, error) {
195192
// DialConfig accepts a string in the AMQP URI format and a configuration for
196193
// the transport and connection setup, returning a new Connection. Defaults to
197194
// a server heartbeat interval of 10 seconds and sets the initial read deadline
198-
// to 30 seconds.
195+
// to 30 seconds. The heartbeat interval specified in the AMQP URI takes precedence
196+
// over the value specified in the config. To disable heartbeats, you must use
197+
// the AMQP URI and set heartbeat=0 there.
199198
func DialConfig(url string, config Config) (*Connection, error) {
200199
var err error
201200
var conn net.Conn
@@ -206,18 +205,50 @@ func DialConfig(url string, config Config) (*Connection, error) {
206205
}
207206

208207
if config.SASL == nil {
209-
config.SASL = []Authentication{uri.PlainAuth()}
208+
if uri.AuthMechanism != nil {
209+
for _, identifier := range uri.AuthMechanism {
210+
switch strings.ToUpper(identifier) {
211+
case "PLAIN":
212+
config.SASL = append(config.SASL, uri.PlainAuth())
213+
case "AMQPLAIN":
214+
config.SASL = append(config.SASL, uri.AMQPlainAuth())
215+
case "EXTERNAL":
216+
config.SASL = append(config.SASL, &ExternalAuth{})
217+
default:
218+
return nil, fmt.Errorf("unsupported auth_mechanism: %v", identifier)
219+
}
220+
}
221+
} else {
222+
config.SASL = []Authentication{uri.PlainAuth()}
223+
}
210224
}
211225

212226
if config.Vhost == "" {
213227
config.Vhost = uri.Vhost
214228
}
215229

230+
if uri.Heartbeat.hasValue {
231+
config.Heartbeat = uri.Heartbeat.value
232+
} else {
233+
if config.Heartbeat == 0 {
234+
config.Heartbeat = defaultHeartbeat
235+
}
236+
}
237+
238+
if config.ChannelMax == 0 {
239+
config.ChannelMax = uri.ChannelMax
240+
}
241+
242+
connectionTimeout := defaultConnectionTimeout
243+
if uri.ConnectionTimeout != 0 {
244+
connectionTimeout = time.Duration(uri.ConnectionTimeout) * time.Millisecond
245+
}
246+
216247
addr := net.JoinHostPort(uri.Host, strconv.FormatInt(int64(uri.Port), 10))
217248

218249
dialer := config.Dial
219250
if dialer == nil {
220-
dialer = DefaultDial(defaultConnectionTimeout)
251+
dialer = DefaultDial(connectionTimeout)
221252
}
222253

223254
conn, err = dialer("tcp", addr)

types.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,3 +553,16 @@ type bodyFrame struct {
553553
}
554554

555555
func (f *bodyFrame) channel() uint16 { return f.ChannelId }
556+
557+
type heartbeatDuration struct {
558+
value time.Duration
559+
hasValue bool
560+
}
561+
562+
func newHeartbeatDurationFromSeconds(s int) heartbeatDuration {
563+
v := time.Duration(s) * time.Second
564+
return heartbeatDuration{
565+
value: v,
566+
hasValue: true,
567+
}
568+
}

uri.go

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package amqp091
77

88
import (
99
"errors"
10+
"fmt"
1011
"net"
1112
"net/url"
1213
"strconv"
@@ -32,16 +33,20 @@ var defaultURI = URI{
3233

3334
// URI represents a parsed AMQP URI string.
3435
type URI struct {
35-
Scheme string
36-
Host string
37-
Port int
38-
Username string
39-
Password string
40-
Vhost string
41-
CertFile string // client TLS auth - path to certificate (PEM)
42-
CACertFile string // client TLS auth - path to CA certificate (PEM)
43-
KeyFile string // client TLS auth - path to private key (PEM)
44-
ServerName string // client TLS auth - server name
36+
Scheme string
37+
Host string
38+
Port int
39+
Username string
40+
Password string
41+
Vhost string
42+
CertFile string // client TLS auth - path to certificate (PEM)
43+
CACertFile string // client TLS auth - path to CA certificate (PEM)
44+
KeyFile string // client TLS auth - path to private key (PEM)
45+
ServerName string // client TLS auth - server name
46+
AuthMechanism []string
47+
Heartbeat heartbeatDuration
48+
ConnectionTimeout int
49+
ChannelMax uint16
4550
}
4651

4752
// ParseURI attempts to parse the given AMQP URI according to the spec.
@@ -62,6 +67,10 @@ type URI struct {
6267
// keyfile: <path/to/client_key.pem>
6368
// cacertfile: <path/to/ca.pem>
6469
// server_name_indication: <server name>
70+
// auth_mechanism: <one or more: plain, amqplain, external>
71+
// heartbeat: <seconds (integer)>
72+
// connection_timeout: <milliseconds (integer)>
73+
// channel_max: <max number of channels (integer)>
6574
//
6675
// If cacertfile is not provided, system CA certificates will be used.
6776
// Mutual TLS (client auth) will be enabled only in case keyfile AND certfile provided.
@@ -134,6 +143,31 @@ func ParseURI(uri string) (URI, error) {
134143
builder.KeyFile = params.Get("keyfile")
135144
builder.CACertFile = params.Get("cacertfile")
136145
builder.ServerName = params.Get("server_name_indication")
146+
builder.AuthMechanism = params["auth_mechanism"]
147+
148+
if params.Has("heartbeat") {
149+
value, err := strconv.Atoi(params.Get("heartbeat"))
150+
if err != nil {
151+
return builder, fmt.Errorf("heartbeat is not an integer: %v", err)
152+
}
153+
builder.Heartbeat = newHeartbeatDurationFromSeconds(value)
154+
}
155+
156+
if params.Has("connection_timeout") {
157+
value, err := strconv.Atoi(params.Get("connection_timeout"))
158+
if err != nil {
159+
return builder, fmt.Errorf("connection_timeout is not an integer: %v", err)
160+
}
161+
builder.ConnectionTimeout = value
162+
}
163+
164+
if params.Has("channel_max") {
165+
value, err := strconv.ParseUint(params.Get("channel_max"), 10, 16)
166+
if err != nil {
167+
return builder, fmt.Errorf("connection_timeout is not an integer: %v", err)
168+
}
169+
builder.ChannelMax = uint16(value)
170+
}
137171

138172
return builder, nil
139173
}

uri_test.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
package amqp091
77

88
import (
9+
"reflect"
910
"testing"
11+
"time"
1012
)
1113

1214
// Test matrix defined on http://www.rabbitmq.com/uri-spec.html
@@ -388,3 +390,26 @@ func TestURITLSConfig(t *testing.T) {
388390
t.Fatal("Server name not set")
389391
}
390392
}
393+
394+
func TestURIParameters(t *testing.T) {
395+
url := "amqps://foo.bar/?auth_mechanism=plain&auth_mechanism=amqpplain&heartbeat=2&connection_timeout=5000&channel_max=8"
396+
uri, err := ParseURI(url)
397+
if err != nil {
398+
t.Fatal("Could not parse")
399+
}
400+
if !reflect.DeepEqual(uri.AuthMechanism, []string{"plain", "amqpplain"}) {
401+
t.Fatal("AuthMechanism not set")
402+
}
403+
if !uri.Heartbeat.hasValue {
404+
t.Fatal("Heartbeat not set")
405+
}
406+
if uri.Heartbeat.value != time.Duration(2)*time.Second {
407+
t.Fatal("Heartbeat not set")
408+
}
409+
if uri.ConnectionTimeout != 5000 {
410+
t.Fatal("ConnectionTimeout not set")
411+
}
412+
if uri.ChannelMax != 8 {
413+
t.Fatal("ChannelMax name not set")
414+
}
415+
}

0 commit comments

Comments
 (0)