Skip to content

Commit

Permalink
refactor: simplify host normalization and enhance hostname validation…
Browse files Browse the repository at this point in the history
… functions
  • Loading branch information
root4loot committed Oct 3, 2024
1 parent 8ab888e commit 7fc2455
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 110 deletions.
47 changes: 0 additions & 47 deletions hostutil/hostutil.go
Original file line number Diff line number Diff line change
@@ -1,56 +1,9 @@
package hostutil

import (
"fmt"
"net"
"strconv"
"strings"
)

// NormalizeHost takes a host input (either a domain name or hostname) and returns it in a standardized format.
// It converts the input to lowercase, removes redundant ports (e.g., 443 for HTTPS and 80 for HTTP),
// validates that the input is a valid fully qualified domain name (FQDN), and ensures that the hostname is properly formatted.
func NormalizeHost(host string) (string, error) {
host = strings.TrimSpace(strings.ToLower(host))

if strings.Contains(host, "://") || strings.Contains(host, "/") {
return "", fmt.Errorf("input contains URL scheme or path: %s", host)
}

hostname, port, err := net.SplitHostPort(host)
if err != nil {
if strings.Contains(host, ":") {
return "", fmt.Errorf("failed to parse host input: %s", host)
}
hostname = host
}

if port != "" {
if _, err := strconv.Atoi(port); err != nil {
return "", fmt.Errorf("invalid port number: %s", port)
}
}

if net.ParseIP(hostname) == nil && !IsValidHostname(hostname) {
return "", fmt.Errorf("invalid hostname or IP address: %s", hostname)
}

if !strings.Contains(hostname, ".") {
return "", fmt.Errorf("invalid FQDN: %s", hostname)
}

switch port {
case "443", "80":
port = ""
}

if port != "" {
hostname = net.JoinHostPort(hostname, port)
}

return hostname, nil
}

// IsValidHostname checks if the given hostname is valid based on RFC 1123.
func IsValidHostname(hostname string) bool {
if len(hostname) == 0 || len(hostname) > 255 {
Expand Down
64 changes: 36 additions & 28 deletions hostutil/hostutil_test.go
Original file line number Diff line number Diff line change
@@ -1,42 +1,50 @@
package hostutil

import (
"strings"
"testing"
)

func TestNormalizeHost(t *testing.T) {
func TestIsValidHostname(t *testing.T) {
tests := []struct {
input string
expected string
hasError bool
hostname string
valid bool
}{
{"example.com", "example.com", false},
{"Example.COM", "example.com", false},
{"example.com:8080", "example.com:8080", false},
{"example.com:80", "example.com", false},
{"example.com:443", "example.com", false},
{"subdomain.example.com:443", "subdomain.example.com", false},
{"subdomain.example.com:80", "subdomain.example.com", false},
{" example.com ", "example.com", false},
{"invalid_host:port", "", true},
{"subdomain:invalidport", "", true},
{"http://example.com", "", true},
{"https://example.com", "", true},
{"ftp://example.com", "", true},
{"example.com/path", "", true},
{"http://example.com:443", "", true},
{"example", "", true},
{"localhost", "", true},
{"example.com", true},
{"localhost", true},
{"sub.domain.example.com", true},
{"example", true},
{"example123.com", true},
{"123example.com", true},
{"example-com", true},
{"example.com-", false},
{"-example.com", false},
{"exa_mple.com", false},
{"example..com", false},
{"", false},
{strings.Repeat("a", 256), false},
{"example!.com", false},
{"example .com", false},
{".example.com", false},
{"example.com.", false},
{"ex%ample.com", false},
{"example.com/", false},
{"example..com", false},
{"-example-.com", false},
{"ex--ample.com", true},
{"example.-com", false},
{"example.com-", false},
{"example-.com", false},
{"exa*mple.com", false},
{"example@com", false},
{"example,com", false},
}

for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result, err := NormalizeHost(tt.input)
if (err != nil) != tt.hasError {
t.Errorf("expected error status %v, got %v (error: %v)", tt.hasError, (err != nil), err)
}
if result != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, result)
t.Run(tt.hostname, func(t *testing.T) {
result := IsValidHostname(tt.hostname)
if result != tt.valid {
t.Errorf("IsValidHostname(%q) = %v; want %v", tt.hostname, result, tt.valid)
}
})
}
Expand Down
29 changes: 11 additions & 18 deletions urlutil/urlutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,27 +98,20 @@ func HasScheme(rawURL string) bool {
return re.MatchString(rawURL)
}

// EnsureScheme ensures a URL has a scheme. If no scheme is provided, it defaults to "http".
func EnsureScheme(rawURL string, scheme ...string) string {
if rawURL == "" {
return rawURL
}

defaultScheme := "http"
if len(scheme) > 0 && scheme[0] != "" {
defaultScheme = scheme[0]
}

u, err := url.Parse(rawURL)
if err != nil {
return defaultScheme + "://" + rawURL
// EnsureHTTP ensures a URL has an HTTP scheme
func EnsureHTTP(rawURL string) string {
if !HasScheme(rawURL) {
rawURL = "http://" + rawURL
}
return rawURL
}

if u.Scheme == "" {
u.Scheme = defaultScheme
// EnsureHTTPS ensures a URL has an HTTPS scheme
func EnsureHTTPS(rawURL string) string {
if !HasScheme(rawURL) {
rawURL = "https://" + rawURL
}

return u.String()
return rawURL
}

// HasFileExtension checks if the given rawURL string has a file extension in its path
Expand Down
99 changes: 82 additions & 17 deletions urlutil/urlutil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,23 +75,6 @@ func TestHasScheme(t *testing.T) {
}
}

func TestEnsureScheme(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"example.com", "http://example.com"},
{"http://example.com", "http://example.com"},
}

for _, test := range tests {
result := EnsureScheme(test.input)
if result != test.expected {
t.Errorf("EnsureScheme(%s) = %s; want %s", test.input, result, test.expected)
}
}
}

func TestHasFileExtension(t *testing.T) {
tests := []struct {
input string
Expand Down Expand Up @@ -228,3 +211,85 @@ func TestIsMediaExt(t *testing.T) {
}
}
}

func TestRemoveDefaultPort(t *testing.T) {
tests := []struct {
input string
expected string
hasError bool
}{
{"http://example.com:80", "http://example.com", false},
{"https://example.com:443", "https://example.com", false},
{"http://example.com:8080", "http://example.com:8080", false},
{"https://example.com:8443", "https://example.com:8443", false},
{"ftp://example.com:21", "ftp://example.com", false},
{"ftp://example.com:2121", "ftp://example.com:2121", false},
{"http://example.com", "http://example.com", false},
{"https://example.com", "https://example.com", false},
{"invalid_url", "", true},
{"example.com", "", true},
}

for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result, err := RemoveDefaultPort(tt.input)
if (err != nil) != tt.hasError {
t.Errorf("expected error status %v, got %v (error: %v)", tt.hasError, (err != nil), err)
}
if result != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, result)
}
})
}
}

func TestEnsureHTTP(t *testing.T) {
tests := []struct {
rawURL string
expected string
}{
{"example.com", "http://example.com"},
{"http://example.com", "http://example.com"},
{"https://example.com", "https://example.com"},
{"example.com/path", "http://example.com/path"},
{"localhost", "http://localhost"},
{"http://localhost", "http://localhost"},
{"https://localhost", "https://localhost"},
{"192.168.0.1", "http://192.168.0.1"},
{"http://192.168.0.1", "http://192.168.0.1"},
}

for _, tt := range tests {
t.Run(tt.rawURL, func(t *testing.T) {
result := EnsureHTTP(tt.rawURL)
if result != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, result)
}
})
}
}

func TestEnsureHTTPS(t *testing.T) {
tests := []struct {
rawURL string
expected string
}{
{"example.com", "https://example.com"},
{"http://example.com", "http://example.com"},
{"https://example.com", "https://example.com"},
{"example.com/path", "https://example.com/path"},
{"localhost", "https://localhost"},
{"https://localhost", "https://localhost"},
{"192.168.0.1", "https://192.168.0.1"},
{"https://192.168.0.1", "https://192.168.0.1"},
}

for _, tt := range tests {
t.Run(tt.rawURL, func(t *testing.T) {
result := EnsureHTTPS(tt.rawURL)
if result != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, result)
}
})
}
}

0 comments on commit 7fc2455

Please sign in to comment.