diff --git a/cmd/zetacored/flags.go b/cmd/zetacored/flags.go new file mode 100644 index 0000000000..2df863b377 --- /dev/null +++ b/cmd/zetacored/flags.go @@ -0,0 +1,42 @@ +package main + +import ( + "fmt" + + "github.com/spf13/cobra" +) + +var KeyAddCommand = []string{"keys", "add"} + +const ( + HDPathFlag = "hd-path" + HDPathEthereum = "m/44'/60'/0'/0/0" +) + +// SetEthereumHDPath sets the default HD path to Ethereum's +func SetEthereumHDPath(cmd *cobra.Command) error { + return ReplaceFlag(cmd, KeyAddCommand, HDPathFlag, HDPathEthereum) +} + +// ReplaceFlag replaces the default value of a flag of a sub-command +func ReplaceFlag(cmd *cobra.Command, subCommand []string, flagName, newDefaultValue string) error { + // Find the sub-command + c, _, err := cmd.Find(subCommand) + if err != nil { + return fmt.Errorf("failed to find %v sub-command: %v", subCommand, err) + } + + // Get the flag from the sub-command + f := c.Flags().Lookup(flagName) + if f == nil { + return fmt.Errorf("%s flag not found in %v sub-command", flagName, subCommand) + } + + // Set the default value for the flag + f.DefValue = newDefaultValue + if err := f.Value.Set(newDefaultValue); err != nil { + return fmt.Errorf("failed to set the value of %s flag: %v", flagName, err) + } + + return nil +} diff --git a/cmd/zetacored/flags_test.go b/cmd/zetacored/flags_test.go new file mode 100644 index 0000000000..1fbdc4d4d0 --- /dev/null +++ b/cmd/zetacored/flags_test.go @@ -0,0 +1,92 @@ +package main_test + +import ( + "errors" + "testing" + + "github.com/spf13/cobra" + "github.com/stretchr/testify/require" + zetacore "github.com/zeta-chain/zetacore/cmd/zetacored" +) + +// alwaysErrorValue allows to test f.Value.Set failure +type alwaysErrorValue struct{} + +func (a *alwaysErrorValue) Set(string) error { return errors.New("error") } +func (a *alwaysErrorValue) String() string { return "" } +func (a *alwaysErrorValue) Type() string { return "string" } + +func TestReplaceFlag(t *testing.T) { + // Setting up a mock command structure + rootCmd := &cobra.Command{Use: "app"} + + fooCmd := &cobra.Command{Use: "foo"} + barCmd := &cobra.Command{Use: "bar"} + + barCmd.Flags().String("baz", "old", "Bar") + barCmd.Flags().Var(&alwaysErrorValue{}, "error", "Always fails to set") + + fooCmd.AddCommand(barCmd) + rootCmd.AddCommand(fooCmd) + + tests := []struct { + name string + cmd *cobra.Command + subCommand []string + flagName string + newDefaultValue string + wantErr bool + expectedValue string + }{ + { + name: "Replace valid flag", + cmd: rootCmd, + subCommand: []string{"foo", "bar"}, + flagName: "baz", + newDefaultValue: "new", + wantErr: false, + expectedValue: "new", + }, + { + name: "Sub-command not found", + cmd: rootCmd, + subCommand: []string{"key", "nonexistent"}, + flagName: "baz", + newDefaultValue: "new", + wantErr: true, + }, + { + name: "Flag not found", + cmd: rootCmd, + subCommand: []string{"foo", "bar"}, + flagName: "nonexistent", + newDefaultValue: "new", + wantErr: true, + }, + { + name: "Flag value cannot be set", + cmd: rootCmd, + subCommand: []string{"foo", "bar"}, + flagName: "error", + newDefaultValue: "new", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := zetacore.ReplaceFlag(tt.cmd, tt.subCommand, tt.flagName, tt.newDefaultValue) + + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + + // Check if the value was replaced correctly + c, _, _ := tt.cmd.Find(tt.subCommand) + f := c.Flags().Lookup(tt.flagName) + require.Equal(t, tt.expectedValue, f.DefValue) + } + }) + } +} diff --git a/cmd/zetacored/root.go b/cmd/zetacored/root.go index c4635ebd8f..84713de816 100644 --- a/cmd/zetacored/root.go +++ b/cmd/zetacored/root.go @@ -2,6 +2,7 @@ package main import ( "errors" + "fmt" "io" "os" "path/filepath" @@ -142,6 +143,11 @@ func initRootCmd(rootCmd *cobra.Command, encodingConfig appparams.EncodingConfig ethermintclient.KeyCommands(app.DefaultNodeHome), ) + // replace the default hd-path for the key add command + if err := SetEthereumHDPath(rootCmd); err != nil { + fmt.Printf("warning: unable to set default HD path: %v\n", err) + } + rootCmd.AddCommand(server.RosettaCommand(encodingConfig.InterfaceRegistry, encodingConfig.Codec)) }