Skip to content

Commit

Permalink
feat(cmd): New Uint16 flag type
Browse files Browse the repository at this point in the history
  • Loading branch information
hugoghx committed Aug 30, 2024
1 parent 6d2313a commit 5b81bf9
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 0 deletions.
64 changes: 64 additions & 0 deletions internal/cmd/base/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,70 @@ func (i *uint64Value) String() string { return strconv.FormatUint(uint64(*i.tar
func (i *uint64Value) Example() string { return "uint" }
func (i *uint64Value) Hidden() bool { return i.hidden }

// -- Uint16Var and uint16Value
type Uint16Var struct {
Name string
Aliases []string
Usage string
Default uint16
Hidden bool
EnvVar string
Target *uint16
Completion complete.Predictor
}

func (f *FlagSet) Uint16Var(i *Uint16Var) {
initial := i.Default
if v, exist := os.LookupEnv(i.EnvVar); exist {
if i, err := strconv.ParseUint(v, 0, 16); err == nil {
initial = uint16(i)
}
}

def := ""
if i.Default != 0 {
strconv.FormatUint(uint64(i.Default), 10)
}

f.VarFlag(&VarFlag{
Name: i.Name,
Aliases: i.Aliases,
Usage: i.Usage,
Default: def,
EnvVar: i.EnvVar,
Value: newUint16Value(initial, i.Target, i.Hidden),
Completion: i.Completion,
})
}

type uint16Value struct {
hidden bool
target *uint16
}

func newUint16Value(def uint16, target *uint16, hidden bool) *uint16Value {
*target = def
return &uint16Value{
hidden: hidden,
target: target,
}
}

func (i *uint16Value) Set(s string) error {
v, err := strconv.ParseUint(s, 0, 16)
if err != nil {
return err
}

*i.target = uint16(v)
return nil
}

func (i *uint16Value) Get() any { return uint64(*i.target) }
func (i *uint16Value) String() string { return strconv.FormatUint(uint64(*i.target), 10) }
func (i *uint16Value) Example() string { return "uint" }
func (i *uint16Value) Hidden() bool { return i.hidden }

// -- StringVar and stringValue
type StringVar struct {
Name string
Expand Down
89 changes: 89 additions & 0 deletions internal/cmd/base/flags_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
package base

import (
"os"
"testing"

"github.com/mitchellh/cli"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -152,3 +154,90 @@ func TestFlagSet_StringSliceMapVar_NullCheck(t *testing.T) {
})
}
}

func TestUint16Var(t *testing.T) {
t.Parallel()

target := uint16(0)

flagSets := NewFlagSets(cli.NewMockUi())
f := flagSets.NewFlagSet("testset")
f.Uint16Var(&Uint16Var{
Name: "test_name",
Aliases: []string{"test_alias1"},
Usage: "test_usage",
Default: 1,
Hidden: false,
Target: &target,
})
require.Equal(t, uint16(1), target) // Should immediately default.

// Value that overflows uint16 should error.
err := flagSets.Parse([]string{"-test_name", "66000"})
require.EqualError(t, err, "invalid value \"66000\" for flag -test_name: strconv.ParseUint: parsing \"66000\": value out of range")
require.Equal(t, uint16(1), target)

// Value that overflows uint16 (via alias) should error.
err = flagSets.Parse([]string{"-test_alias1", "66000"})
require.EqualError(t, err, "invalid value \"66000\" for flag -test_alias1: strconv.ParseUint: parsing \"66000\": value out of range")
require.Equal(t, uint16(1), target)

// Negative value should error.
err = flagSets.Parse([]string{"-test_name", "-1"})
require.EqualError(t, err, "invalid value \"-1\" for flag -test_name: strconv.ParseUint: parsing \"-1\": invalid syntax")
require.Equal(t, uint16(1), target)

// Negative value (via alias) should error.
err = flagSets.Parse([]string{"-test_alias1", "-1"})
require.EqualError(t, err, "invalid value \"-1\" for flag -test_alias1: strconv.ParseUint: parsing \"-1\": invalid syntax")
require.Equal(t, uint16(1), target)

// Valid value should be put into target.
err = flagSets.Parse([]string{"-test_name", "123"})
require.NoError(t, err)
require.Equal(t, uint16(123), target)

// Valid value (using alias) should be put into target.
err = flagSets.Parse([]string{"-test_alias1", "456"})
require.NoError(t, err)
require.Equal(t, uint16(456), target)

// Env var tests.
envTarget := uint16(0)
envVarName := "test_uint16_env_var"

envFlagSets := NewFlagSets(cli.NewMockUi())
ef := envFlagSets.NewFlagSet("env_testset")

require.NoError(t, os.Setenv(envVarName, "66000"))
ef.Uint16Var(&Uint16Var{
Name: "test_env_name1",
Default: 1,
EnvVar: envVarName,
Target: &envTarget,
})
require.Equal(t, uint16(1), envTarget) // Should be set to default because env value parse will have failed.
require.NoError(t, os.Unsetenv(envVarName))
envTarget = uint16(0)

require.NoError(t, os.Setenv(envVarName, "-1"))
ef.Uint16Var(&Uint16Var{
Name: "test_env_name2",
Default: 1,
EnvVar: envVarName,
Target: &envTarget,
})
require.Equal(t, uint16(1), envTarget) // Should be set to default because env value parse will have failed.
require.NoError(t, os.Unsetenv(envVarName))
envTarget = uint16(0)

require.NoError(t, os.Setenv(envVarName, "123"))
ef.Uint16Var(&Uint16Var{
Name: "test_env_name3",
Default: 1,
EnvVar: envVarName,
Target: &envTarget,
})
require.Equal(t, uint16(123), envTarget) // Should be set to what was set in env.
require.NoError(t, os.Unsetenv(envVarName))
}

0 comments on commit 5b81bf9

Please sign in to comment.