From 70d14d8ececa3675d8ec7c00fcc70c5e1b1465e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vilius=20Grigali=C5=ABnas?= Date: Thu, 29 Feb 2024 14:23:41 +0200 Subject: [PATCH] Add support for additional AMQP URI query parameters https://www.rabbitmq.com/docs/uri-query-parameters specifies several parameters that are used in this library, but not yet supported in URIs. This commit adds support for the following parameters: auth_mechanism heartbeat connection_timeout channel_max Fix default value check when setting SASL authentication from URI Add documentation for added query parameters Add support for additional AMQP URI query parameters https://www.rabbitmq.com/docs/uri-query-parameters specifies several parameters that are used in this library, but not yet supported in URIs. This commit adds support for the following parameters: auth_mechanism heartbeat connection_timeout channel_max Fix default value check when setting SASL authentication from URI Fix ChannelMax type mismatch Use URI heartbeat Bump versions on Windows --- .ci/versions.json | 4 ++-- connection.go | 41 +++++++++++++++++++++++++++++------ types.go | 13 ++++++++++++ uri.go | 54 ++++++++++++++++++++++++++++++++++++++--------- uri_test.go | 24 +++++++++++++++++++++ 5 files changed, 118 insertions(+), 18 deletions(-) diff --git a/.ci/versions.json b/.ci/versions.json index 89002ac..231e0b5 100644 --- a/.ci/versions.json +++ b/.ci/versions.json @@ -1,4 +1,4 @@ { - "erlang": "26.1.1", - "rabbitmq": "3.12.6" + "erlang": "26.2.2", + "rabbitmq": "3.13.0" } diff --git a/connection.go b/connection.go index 0f3f6a4..9a70fdc 100644 --- a/connection.go +++ b/connection.go @@ -157,8 +157,7 @@ func DefaultDial(connectionTimeout time.Duration) func(network, addr string) (ne // scheme. It is equivalent to calling DialTLS(amqp, nil). func Dial(url string) (*Connection, error) { return DialConfig(url, Config{ - Heartbeat: defaultHeartbeat, - Locale: defaultLocale, + Locale: defaultLocale, }) } @@ -169,7 +168,6 @@ func Dial(url string) (*Connection, error) { // DialTLS uses the provided tls.Config when encountering an amqps:// scheme. func DialTLS(url string, amqps *tls.Config) (*Connection, error) { return DialConfig(url, Config{ - Heartbeat: defaultHeartbeat, TLSClientConfig: amqps, Locale: defaultLocale, }) @@ -186,7 +184,6 @@ func DialTLS(url string, amqps *tls.Config) (*Connection, error) { // amqps:// scheme. func DialTLS_ExternalAuth(url string, amqps *tls.Config) (*Connection, error) { return DialConfig(url, Config{ - Heartbeat: defaultHeartbeat, TLSClientConfig: amqps, SASL: []Authentication{&ExternalAuth{}}, }) @@ -206,18 +203,50 @@ func DialConfig(url string, config Config) (*Connection, error) { } if config.SASL == nil { - config.SASL = []Authentication{uri.PlainAuth()} + if uri.AuthMechanism != nil { + for _, identifier := range uri.AuthMechanism { + switch strings.ToUpper(identifier) { + case "PLAIN": + config.SASL = append(config.SASL, uri.PlainAuth()) + case "AMQPLAIN": + config.SASL = append(config.SASL, uri.AMQPlainAuth()) + case "EXTERNAL": + config.SASL = append(config.SASL, &ExternalAuth{}) + default: + return nil, fmt.Errorf("unsupported auth_mechanism: %v", identifier) + } + } + } else { + config.SASL = []Authentication{uri.PlainAuth()} + } } if config.Vhost == "" { config.Vhost = uri.Vhost } + if uri.Heartbeat.hasValue { + config.Heartbeat = uri.Heartbeat.value + } else { + if config.Heartbeat == 0 { + config.Heartbeat = defaultHeartbeat + } + } + + if config.ChannelMax == 0 { + config.ChannelMax = uri.ChannelMax + } + + connectionTimeout := defaultConnectionTimeout + if uri.ConnectionTimeout != 0 { + connectionTimeout = time.Duration(uri.ConnectionTimeout) * time.Millisecond + } + addr := net.JoinHostPort(uri.Host, strconv.FormatInt(int64(uri.Port), 10)) dialer := config.Dial if dialer == nil { - dialer = DefaultDial(defaultConnectionTimeout) + dialer = DefaultDial(connectionTimeout) } conn, err = dialer("tcp", addr) diff --git a/types.go b/types.go index d7d8f26..1e15ed0 100644 --- a/types.go +++ b/types.go @@ -553,3 +553,16 @@ type bodyFrame struct { } func (f *bodyFrame) channel() uint16 { return f.ChannelId } + +type heartbeatDuration struct { + value time.Duration + hasValue bool +} + +func newHeartbeatDurationFromSeconds(s int) heartbeatDuration { + v := time.Duration(s) * time.Second + return heartbeatDuration{ + value: v, + hasValue: true, + } +} diff --git a/uri.go b/uri.go index 87ef09e..f1db5e8 100644 --- a/uri.go +++ b/uri.go @@ -7,6 +7,7 @@ package amqp091 import ( "errors" + "fmt" "net" "net/url" "strconv" @@ -32,16 +33,20 @@ var defaultURI = URI{ // URI represents a parsed AMQP URI string. type URI struct { - Scheme string - Host string - Port int - Username string - Password string - Vhost string - CertFile string // client TLS auth - path to certificate (PEM) - CACertFile string // client TLS auth - path to CA certificate (PEM) - KeyFile string // client TLS auth - path to private key (PEM) - ServerName string // client TLS auth - server name + Scheme string + Host string + Port int + Username string + Password string + Vhost string + CertFile string // client TLS auth - path to certificate (PEM) + CACertFile string // client TLS auth - path to CA certificate (PEM) + KeyFile string // client TLS auth - path to private key (PEM) + ServerName string // client TLS auth - server name + AuthMechanism []string + Heartbeat heartbeatDuration + ConnectionTimeout int + ChannelMax uint16 } // ParseURI attempts to parse the given AMQP URI according to the spec. @@ -62,6 +67,10 @@ type URI struct { // keyfile: // cacertfile: // server_name_indication: +// auth_mechanism: +// heartbeat: +// connection_timeout: +// channel_max: // // If cacertfile is not provided, system CA certificates will be used. // Mutual TLS (client auth) will be enabled only in case keyfile AND certfile provided. @@ -134,6 +143,31 @@ func ParseURI(uri string) (URI, error) { builder.KeyFile = params.Get("keyfile") builder.CACertFile = params.Get("cacertfile") builder.ServerName = params.Get("server_name_indication") + builder.AuthMechanism = params["auth_mechanism"] + + if params.Has("heartbeat") { + value, err := strconv.Atoi(params.Get("heartbeat")) + if err != nil { + return builder, fmt.Errorf("heartbeat is not an integer: %v", err) + } + builder.Heartbeat = newHeartbeatDurationFromSeconds(value) + } + + if params.Has("connection_timeout") { + value, err := strconv.Atoi(params.Get("connection_timeout")) + if err != nil { + return builder, fmt.Errorf("connection_timeout is not an integer: %v", err) + } + builder.ConnectionTimeout = value + } + + if params.Has("channel_max") { + value, err := strconv.ParseUint(params.Get("channel_max"), 10, 16) + if err != nil { + return builder, fmt.Errorf("connection_timeout is not an integer: %v", err) + } + builder.ChannelMax = uint16(value) + } return builder, nil } diff --git a/uri_test.go b/uri_test.go index a369441..eef5f50 100644 --- a/uri_test.go +++ b/uri_test.go @@ -6,6 +6,7 @@ package amqp091 import ( + "reflect" "testing" ) @@ -388,3 +389,26 @@ func TestURITLSConfig(t *testing.T) { t.Fatal("Server name not set") } } + +func TestURIParameters(t *testing.T) { + url := "amqps://foo.bar/?auth_mechanism=plain&auth_mechanism=amqpplain&heartbeat=2&connection_timeout=5000&channel_max=8" + uri, err := ParseURI(url) + if err != nil { + t.Fatal("Could not parse") + } + if !reflect.DeepEqual(uri.AuthMechanism, []string{"plain", "amqpplain"}) { + t.Fatal("AuthMechanism not set") + } + if !uri.Heartbeat.hasValue { + t.Fatal("Heartbeat not set") + } + if uri.Heartbeat.value != 2 { + t.Fatal("Heartbeat not set") + } + if uri.ConnectionTimeout != 5000 { + t.Fatal("ConnectionTimeout not set") + } + if uri.ChannelMax != 8 { + t.Fatal("ChannelMax name not set") + } +}