From 589ab4186afa30741a3655d64abcf4296767cf0d Mon Sep 17 00:00:00 2001 From: Aitor Perez Cedres Date: Mon, 6 May 2024 15:55:13 +0100 Subject: [PATCH] Fix URI stringer implementation The string method was not including possible query parameters, namely, TLS parameters. This fix brings our URI implementation in accordance with the standard libary net/url, which does include query parameters in their String() function. Signed-off-by: Aitor Perez Cedres --- tls_test.go | 18 ++++++++++++++++++ uri.go | 24 ++++++++++++++++++++++++ uri_test.go | 28 ++++++++++++++++++++++++++++ 3 files changed, 70 insertions(+) diff --git a/tls_test.go b/tls_test.go index 44c5fb9..a3916dd 100644 --- a/tls_test.go +++ b/tls_test.go @@ -82,6 +82,24 @@ func startTLSServer(t *testing.T, cfg *tls.Config) tlsServer { return s } +func TestTlsConfigFromUriPushdownServerNameIndication(t *testing.T) { + uri := "amqps://user:pass@example.com:5671?server_name_indication=another-hostname.com" + parsedUri, err := ParseURI(uri) + if err != nil { + t.Fatalf("expected to parse URI successfully, got error: %s", err) + } + + tlsConf, err := tlsConfigFromURI(parsedUri) + if err != nil { + t.Fatalf("expected tlsConfigFromURI to succeed, got error: %s", err) + } + + const expectedServerName = "another-hostname.com" + if tlsConf.ServerName != expectedServerName { + t.Fatalf("expected tlsConf server name to equal Uri servername: want %s, got %s", expectedServerName, tlsConf.ServerName) + } +} + // Tests opening a connection of a TLS enabled socket server func TestTLSHandshake(t *testing.T) { srv := startTLSServer(t, tlsServerConfig(t)) diff --git a/uri.go b/uri.go index f1db5e8..ddc4b1a 100644 --- a/uri.go +++ b/uri.go @@ -226,5 +226,29 @@ func (uri URI) String() string { authority.Path = "/" } + if uri.CertFile != "" || uri.KeyFile != "" || uri.CACertFile != "" || uri.ServerName != "" { + rawQuery := strings.Builder{} + if uri.CertFile != "" { + rawQuery.WriteString("certfile=") + rawQuery.WriteString(uri.CertFile) + rawQuery.WriteRune('&') + } + if uri.KeyFile != "" { + rawQuery.WriteString("keyfile=") + rawQuery.WriteString(uri.KeyFile) + rawQuery.WriteRune('&') + } + if uri.CACertFile != "" { + rawQuery.WriteString("cacertfile=") + rawQuery.WriteString(uri.CACertFile) + rawQuery.WriteRune('&') + } + if uri.ServerName != "" { + rawQuery.WriteString("server_name_indication=") + rawQuery.WriteString(uri.ServerName) + } + authority.RawQuery = rawQuery.String() + } + return authority.String() } diff --git a/uri_test.go b/uri_test.go index ca5424d..3139911 100644 --- a/uri_test.go +++ b/uri_test.go @@ -413,3 +413,31 @@ func TestURIParameters(t *testing.T) { t.Fatal("ChannelMax name not set") } } + +func TestURI_ParseUriToString(t *testing.T) { + tests := []struct { + name string + uri string + want string + }{ + {name: "virtual host is set", uri: "amqp://example.com/foobar", want: "amqp://example.com/foobar"}, + {name: "non-default port", uri: "amqp://foo.bar:1234/example", want: "amqp://foo.bar:1234/example"}, + { + name: "TLS with URI parameters", + uri: "amqps://some-host.com/foobar?certfile=/foo/%D0%BF%D1%80%D0%B8%D0%B2%D0%B5%D1%82/cert.pem&keyfile=/foo/%E4%BD%A0%E5%A5%BD/key.pem&cacertfile=C:%5Ccerts%5Cca.pem&server_name_indication=example.com", + want: "amqps://some-host.com/foobar?certfile=/foo/привет/cert.pem&keyfile=/foo/你好/key.pem&cacertfile=C:\\certs\\ca.pem&server_name_indication=example.com", + }, + {name: "only server name indication", uri: "amqps://foo.bar?server_name_indication=example.com", want: "amqps://foo.bar/?server_name_indication=example.com"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + amqpUri, err := ParseURI(tt.uri) + if err != nil { + t.Errorf("ParseURI() error = %v", err) + } + if got := amqpUri.String(); got != tt.want { + t.Errorf("String() = %v, want %v", got, tt.want) + } + }) + } +}