diff --git a/config/config.go b/config/config.go index 214beae6..4a9fece6 100644 --- a/config/config.go +++ b/config/config.go @@ -45,7 +45,7 @@ func (w WatchdogConfig) Process() (string, []string) { return parts[0], parts[1:] } - return parts[0], []string{} + return parts[0], nil } // New create config based upon environmental variables. @@ -101,7 +101,7 @@ func New(env []string) WatchdogConfig { MaxInflight: getInt(envMap, "max_inflight", 0), } - if val := envMap["mode"]; len(val) > 0 { + if val, exists := envMap["mode"]; exists { config.OperationalMode = WatchdogModeConst(val) } @@ -112,17 +112,13 @@ func mapEnv(env []string) map[string]string { mapped := map[string]string{} for _, val := range env { - sep := strings.Index(val, "=") - - if sep > 0 { - key := val[0:sep] - value := val[sep+1:] - mapped[key] = value - } else { - fmt.Println("Bad environment: " + val) + keyValue := strings.SplitN(val, "=", 2) + if len(keyValue) == 2 && keyValue[1] != "" { + mapped[keyValue[0]] = keyValue[1] + continue } + fmt.Println("Bad environment: " + val) } - return mapped } diff --git a/config/config_test.go b/config/config_test.go index 3b4b3725..b6b36d07 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -2,6 +2,7 @@ package config import ( "fmt" + "reflect" "testing" "time" ) @@ -347,3 +348,56 @@ func Test_NonParsableString_parseIntOrDurationValue(t *testing.T) { t.Error(fmt.Sprintf("want: %q got: %q", want, got)) } } + +func Test_mapEnv(t *testing.T) { + type args struct { + env []string + } + tests := []struct { + name string + args args + want map[string]string + }{ + { + name: "change env to map", + args: args{ + env: []string{ + "FOO=BAR", + }, + }, + want: map[string]string{ + "FOO": "BAR", + }, + }, + { + name: "keep '=' of environment value when contains '='", + args: args{ + env: []string{ + "FOO=BAR=BAZ", + }, + }, + want: map[string]string{ + "FOO": "BAR=BAZ", + }, + }, + { + name: "ignore empty value environment", + args: args{ + env: []string{ + "FOO=", + "BAR=BAZ", + }, + }, + want: map[string]string{ + "BAR": "BAZ", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := mapEnv(tt.args.env); !reflect.DeepEqual(got, tt.want) { + t.Errorf("mapEnv() = %v, want %v", got, tt.want) + } + }) + } +}