From 849160d28563192a5e2d458dc26bded61c05866b Mon Sep 17 00:00:00 2001 From: Jonathan Ingram Date: Tue, 7 Jul 2020 18:49:43 +1000 Subject: [PATCH] add URL --- check_test.go | 36 ++++++++++++++++++++++++++++++++++++ env.go | 39 +++++++++++++++++++++++++++++++++++++++ env_test.go | 2 ++ 3 files changed, 77 insertions(+) diff --git a/check_test.go b/check_test.go index e0b505c..23207a3 100644 --- a/check_test.go +++ b/check_test.go @@ -157,6 +157,42 @@ func TestIsDialAddr(t *testing.T) { } } +func TestURL(t *testing.T) { + tests := []struct { + in string + wantErr bool + }{ + // Valid + {"http://localhost", false}, + {"http://localhost:1234", false}, + {"http://192.168.0.1:1234", false}, + {"https://example.com", false}, + {"https://example.com/home", false}, + {"https://example.com/home?a=b", false}, + + // Invalid + {"", true}, + {"ht tp://foo.com", true}, // invalid character in schema + {"http://a b.com/", true}, // no space in host name please + {"cache_object:foo", true}, + } + + env.ResetForTesting() + prefix := env.CmdVar.Name() + _ = env.URL("URL", "URL test") + name := strings.ToUpper(prefix) + "_URL" + + for _, tt := range tests { + t.Run(tt.in, func(t *testing.T) { + os.Setenv(name, tt.in) + + if err := env.Parse(); (err != nil) != tt.wantErr { + t.Errorf("env.Parse() = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + func TestIsPath(t *testing.T) { env.ResetForTesting() diff --git a/env.go b/env.go index e765f96..3fb34ea 100644 --- a/env.go +++ b/env.go @@ -5,6 +5,7 @@ package env // import "code.sajari.com/env" import ( "errors" "fmt" + "net/url" "os" "path" "strconv" @@ -128,6 +129,30 @@ func (v *boolValue) String() string { return strconv.FormatBool(bool(*v)) } +type urlValue url.URL + +func newURLValue(x url.URL, p *url.URL) *urlValue { + *p = x + return (*urlValue)(p) +} + +func (v *urlValue) Set(x string) error { + if x == "" { + return errors.New("empty") + } + u, err := url.Parse(x) + if err != nil { + return err + } + *v = urlValue(*u) + return err +} + +func (v *urlValue) String() string { + u := url.URL(*v) + return u.String() +} + // NewVarSet creates a new variable set with given name. // // If name is non-empty, then all variables will have a strings.ToUpper(name)+"_" @@ -249,6 +274,14 @@ func (v *VarSet) DialAddr(name, usage string) *string { return p } +// URL defines a string variable with specified name, usage string validated as a URL. +// The return value is the address of a URL variable that stores the value of the variable. +func (v *VarSet) URL(name, usage string) *url.URL { + p := new(url.URL) + v.Var(newURLValue(url.URL{}, p), name, usage) + return p +} + // Path defines a string variable with specified name, usage string validated as a local path. // The return value is the address of a string variable that stores the value of the variable. func (v *VarSet) Path(name, usage string) *string { @@ -356,6 +389,12 @@ func DialAddr(name, usage string) *string { return CmdVar.DialAddr(name, usage) } +// URL defines a string variable with specified name, usage string validated as a URL. +// The return value is the address of a URL variable that stores the value of the variable. +func URL(name, usage string) *url.URL { + return CmdVar.URL(name, usage) +} + // Path defines a string variable with specified name, usage string validated as a // local path. // The return value is the address of a string variable that stores the value of the variable. diff --git a/env_test.go b/env_test.go index 4f5e169..24a3a51 100644 --- a/env_test.go +++ b/env_test.go @@ -20,6 +20,7 @@ func TestAll(t *testing.T) { env.Int("INT", "int test") env.BindAddr("LISTEN", "bindaddr test") env.DialAddr("ADDR", "dialaddr test") + env.URL("URL", "URL test") env.String("STRING", "string test") env.Duration("TIMEOUT", "timeout test") @@ -28,6 +29,7 @@ func TestAll(t *testing.T) { "TEST_INT": "1", "TEST_LISTEN": ":1234", "TEST_ADDR": "localhost:1234", + "TEST_URL": "http://localhost:1234/api", "TEST_STRING": "name", "TEST_TIMEOUT": "1m1s", }