From 10a465d0ff6794f4b19160c299905e5daa0cc513 Mon Sep 17 00:00:00 2001 From: Kaloyan Tanev <24719519+KaloyanTanev@users.noreply.github.com> Date: Tue, 1 Oct 2024 16:33:35 +0200 Subject: [PATCH] cmd: move threshold check to CLI level (#3297) A [recent commit](https://github.com/ObolNetwork/charon/commit/98b84e11548e7267e2440fd6f582a9e8fcdfe970) introduced a misbehavior when omitting the optional `--threshold` flag of `create dkg` and `create cluster` commands. Because the threshold configuration is tested before the threshold variable is assigned to the default value `ceil(2*n/3)`, the flag is not optional anymore. This PR fixes this bug by moving the checks at the CLI level and by updating the corresponding tests accordingly. It also adds an input validation check on the [`ThresholdSplit`](https://github.com/ObolNetwork/charon/blob/ced30abb5a8c168b358a9bfc976fbe23927d72de/tbls/herumi.go#L133) and [`ThresholdSplitInsecure`](https://github.com/ObolNetwork/charon/blob/ced30abb5a8c168b358a9bfc976fbe23927d72de/tbls/herumi.go#L83) functions to ensure they are called with a threshold parameter greater than 1. category: bug ticket: none --- cmd/createcluster.go | 27 ++++--- cmd/createcluster_internal_test.go | 96 +++++++++++++++++++------ cmd/createdkg.go | 33 +++++---- cmd/createdkg_internal_test.go | 109 ++++++++++++++++++++++++----- tbls/herumi.go | 8 +++ 5 files changed, 210 insertions(+), 63 deletions(-) diff --git a/cmd/createcluster.go b/cmd/createcluster.go index 249feec92..065c713af 100644 --- a/cmd/createcluster.go +++ b/cmd/createcluster.go @@ -10,7 +10,6 @@ import ( "encoding/json" "fmt" "io" - "math" "net/url" "os" "path" @@ -100,6 +99,22 @@ func newCreateClusterCmd(runFunc func(context.Context, io.Writer, clusterConfig) bindClusterFlags(cmd.Flags(), &conf) bindInsecureFlags(cmd.Flags(), &conf.InsecureKeys) + wrapPreRunE(cmd, func(cmd *cobra.Command, _ []string) error { + thresholdPresent := cmd.Flags().Lookup("threshold").Changed + + if thresholdPresent { + if conf.Threshold < minThreshold { + return errors.New("threshold must be greater than 1", z.Int("threshold", conf.Threshold), z.Int("min", minThreshold)) + } + if conf.Threshold > conf.NumNodes { + return errors.New("threshold cannot be greater than number of operators", + z.Int("threshold", conf.Threshold), z.Int("operators", conf.NumNodes)) + } + } + + return nil + }) + return cmd } @@ -374,16 +389,6 @@ func validateCreateConfig(ctx context.Context, conf clusterConfig) error { return errors.New("number of operators is below minimum", z.Int("operators", conf.NumNodes), z.Int("min", minNodes)) } - // Check for threshold parameter - minThreshold := int(math.Ceil(float64(conf.NumNodes*2) / 3)) - if conf.Threshold < minThreshold { - return errors.New("threshold cannot be smaller than BFT quorum", z.Int("threshold", conf.Threshold), z.Int("min", minThreshold)) - } - if conf.Threshold > conf.NumNodes { - return errors.New("threshold cannot be greater than number of operators", - z.Int("threshold", conf.Threshold), z.Int("operators", conf.NumNodes)) - } - return nil } diff --git a/cmd/createcluster_internal_test.go b/cmd/createcluster_internal_test.go index 95be0bad4..a611206c1 100644 --- a/cmd/createcluster_internal_test.go +++ b/cmd/createcluster_internal_test.go @@ -250,26 +250,6 @@ func TestCreateCluster(t *testing.T) { }, }, }, - { - Name: "threshold greater than the number of operators", - Config: clusterConfig{ - NumNodes: 4, - Threshold: 5, - NumDVs: 1, - Network: defaultNetwork, - }, - expectedErr: "threshold cannot be greater than number of operators", - }, - { - Name: "threshold smaller than BFT quorum", - Config: clusterConfig{ - NumNodes: 4, - Threshold: 2, - NumDVs: 1, - Network: defaultNetwork, - }, - expectedErr: "threshold cannot be smaller than BFT quorum", - }, { Name: "test with number of nodes below minimum", Config: clusterConfig{ @@ -788,6 +768,82 @@ func TestPublish(t *testing.T) { }) } +func TestClusterCLI(t *testing.T) { + feeRecipientArg := "--fee-recipient-addresses=" + validEthAddr + withdrawalArg := "--withdrawal-addresses=" + validEthAddr + + tests := []struct { + name string + network string + nodes string + numValidators string + feeRecipient string + withdrawal string + threshold string + expectedErr string + cleanup func(*testing.T) + }{ + { + name: "threshold below minimum", + nodes: "--nodes=3", + network: "--network=holesky", + numValidators: "--num-validators=1", + feeRecipient: feeRecipientArg, + withdrawal: withdrawalArg, + threshold: "--threshold=1", + expectedErr: "threshold must be greater than 1", + }, + { + name: "threshold above maximum", + nodes: "--nodes=4", + network: "--network=holesky", + numValidators: "--num-validators=1", + feeRecipient: feeRecipientArg, + withdrawal: withdrawalArg, + threshold: "--threshold=5", + expectedErr: "threshold cannot be greater than number of operators", + }, + { + name: "no threshold provided", + nodes: "--nodes=3", + network: "--network=holesky", + numValidators: "--num-validators=1", + feeRecipient: feeRecipientArg, + withdrawal: withdrawalArg, + threshold: "", + expectedErr: "", + cleanup: func(t *testing.T) { + t.Helper() + require.NoError(t, os.RemoveAll("node0")) + require.NoError(t, os.RemoveAll("node1")) + require.NoError(t, os.RemoveAll("node2")) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + cmd := newCreateCmd(newCreateClusterCmd(runCreateCluster)) + if test.threshold != "" { + cmd.SetArgs([]string{"cluster", test.nodes, test.feeRecipient, test.withdrawal, test.network, test.numValidators, test.threshold}) + } else { + cmd.SetArgs([]string{"cluster", test.nodes, test.feeRecipient, test.withdrawal, test.network, test.numValidators}) + } + + err := cmd.Execute() + if test.expectedErr != "" { + require.ErrorContains(t, err, test.expectedErr) + } else { + require.NoError(t, err) + } + + if test.cleanup != nil { + test.cleanup(t) + } + }) + } +} + // mockKeymanagerReq is a mock keymanager request for use in tests. type mockKeymanagerReq struct { Keystores []string `json:"keystores"` diff --git a/cmd/createdkg.go b/cmd/createdkg.go index 79fe1f476..4bc9e28ab 100644 --- a/cmd/createdkg.go +++ b/cmd/createdkg.go @@ -6,7 +6,6 @@ import ( "context" crand "crypto/rand" "encoding/json" - "math" "os" "path" @@ -50,6 +49,22 @@ func newCreateDKGCmd(runFunc func(context.Context, createDKGConfig) error) *cobr bindCreateDKGFlags(cmd, &config) + wrapPreRunE(cmd, func(cmd *cobra.Command, _ []string) error { + thresholdPresent := cmd.Flags().Lookup("threshold").Changed + + if thresholdPresent { + if config.Threshold < minThreshold { + return errors.New("threshold must be greater than 1", z.Int("threshold", config.Threshold), z.Int("min", minThreshold)) + } + if config.Threshold > len(config.OperatorENRs) { + return errors.New("threshold cannot be greater than number of operators", + z.Int("threshold", config.Threshold), z.Int("operators", len(config.OperatorENRs))) + } + } + + return nil + }) + return cmd } @@ -82,7 +97,7 @@ func runCreateDKG(ctx context.Context, conf createDKGConfig) (err error) { conf.Network = eth2util.Goerli.Name } - if err = validateDKGConfig(conf.Threshold, len(conf.OperatorENRs), conf.Network, conf.DepositAmounts); err != nil { + if err = validateDKGConfig(len(conf.OperatorENRs), conf.Network, conf.DepositAmounts); err != nil { return err } @@ -115,7 +130,7 @@ func runCreateDKG(ctx context.Context, conf createDKGConfig) (err error) { safeThreshold := cluster.Threshold(len(conf.OperatorENRs)) if conf.Threshold == 0 { conf.Threshold = safeThreshold - } else if conf.Threshold != safeThreshold { + } else { log.Warn(ctx, "Non standard `--threshold` flag provided, this will affect cluster safety", nil, z.Int("threshold", conf.Threshold), z.Int("safe_threshold", safeThreshold)) } @@ -181,22 +196,12 @@ func validateWithdrawalAddrs(addrs []string, network string) error { } // validateDKGConfig returns an error if any of the provided config parameter is invalid. -func validateDKGConfig(threshold, numOperators int, network string, depositAmounts []int) error { +func validateDKGConfig(numOperators int, network string, depositAmounts []int) error { // Don't allow cluster size to be less than 3. if numOperators < minNodes { return errors.New("number of operators is below minimum", z.Int("operators", numOperators), z.Int("min", minNodes)) } - // Ensure threshold setting is sound - minThreshold := int(math.Ceil(float64(numOperators*2) / 3)) - if threshold < minThreshold { - return errors.New("threshold cannot be smaller than BFT quorum", z.Int("threshold", threshold), z.Int("min", minThreshold)) - } - if threshold > numOperators { - return errors.New("threshold cannot be greater than length of operators", - z.Int("threshold", threshold), z.Int("operators", numOperators)) - } - if !eth2util.ValidNetwork(network) { return errors.New("unsupported network", z.Str("network", network)) } diff --git a/cmd/createdkg_internal_test.go b/cmd/createdkg_internal_test.go index a6d522db5..3e93ebaff 100644 --- a/cmd/createdkg_internal_test.go +++ b/cmd/createdkg_internal_test.go @@ -184,36 +184,109 @@ func TestValidateWithdrawalAddr(t *testing.T) { } func TestValidateDKGConfig(t *testing.T) { - t.Run("threshold exceeds numOperators", func(t *testing.T) { - threshold := 5 - numOperators := 4 - err := validateDKGConfig(threshold, numOperators, "", nil) - require.ErrorContains(t, err, "threshold cannot be greater than length of operators") - }) - - t.Run("threshold equals 1", func(t *testing.T) { - threshold := 1 - numOperators := 3 - err := validateDKGConfig(threshold, numOperators, "", nil) - require.ErrorContains(t, err, "threshold cannot be smaller than BFT quorum") - }) t.Run("insufficient ENRs", func(t *testing.T) { - threshold := 2 numOperators := 2 - err := validateDKGConfig(threshold, numOperators, "", nil) + err := validateDKGConfig(numOperators, "", nil) require.ErrorContains(t, err, "number of operators is below minimum") }) t.Run("invalid network", func(t *testing.T) { - threshold := 3 numOperators := 4 - err := validateDKGConfig(threshold, numOperators, "cosmos", nil) + err := validateDKGConfig(numOperators, "cosmos", nil) require.ErrorContains(t, err, "unsupported network") }) t.Run("wrong deposit amounts sum", func(t *testing.T) { - err := validateDKGConfig(3, 4, "goerli", []int{8, 16}) + err := validateDKGConfig(4, "goerli", []int{8, 16}) require.ErrorContains(t, err, "sum of partial deposit amounts must sum up to 32ETH") }) } + +func TestDKGCLI(t *testing.T) { + var enrs []string + for range minNodes { + enrs = append(enrs, "enr:-JG4QG472ZVvl8ySSnUK9uNVDrP_hjkUrUqIxUC75aayzmDVQedXkjbqc7QKyOOS71VmlqnYzri_taV8ZesFYaoQSIOGAYHtv1WsgmlkgnY0gmlwhH8AAAGJc2VjcDI1NmsxoQKwwq_CAld6oVKOrixE-JzMtvvNgb9yyI-_rwq4NFtajIN0Y3CCDhqDdWRwgg4u") + } + enrArg := "--operator-enrs=" + strings.Join(enrs, ",") + feeRecipientArg := "--fee-recipient-addresses=" + validEthAddr + withdrawalArg := "--withdrawal-addresses=" + validEthAddr + outputDirArg := "--output-dir=.charon" + + tests := []struct { + name string + enr string + feeRecipient string + withdrawal string + outputDir string + threshold string + expectedErr string + prepare func(*testing.T) + cleanup func(*testing.T) + }{ + { + name: "threshold below minimum", + enr: enrArg, + feeRecipient: feeRecipientArg, + withdrawal: withdrawalArg, + outputDir: outputDirArg, + threshold: "--threshold=1", + expectedErr: "threshold must be greater than 1", + }, + { + name: "threshold above maximum", + enr: enrArg, + feeRecipient: feeRecipientArg, + withdrawal: withdrawalArg, + outputDir: outputDirArg, + threshold: "--threshold=4", + expectedErr: "threshold cannot be greater than number of operators", + }, + { + name: "no threshold provided", + enr: enrArg, + feeRecipient: feeRecipientArg, + withdrawal: withdrawalArg, + outputDir: outputDirArg, + threshold: "", + expectedErr: "", + prepare: func(t *testing.T) { + t.Helper() + charonDir := testutil.CreateTempCharonDir(t) + b := []byte("sample definition") + require.NoError(t, os.WriteFile(path.Join(charonDir, "cluster-definition.json"), b, 0o600)) + }, + cleanup: func(t *testing.T) { + t.Helper() + err := os.RemoveAll(".charon") + require.NoError(t, err) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.prepare != nil { + test.prepare(t) + } + + cmd := newCreateCmd(newCreateDKGCmd(runCreateDKG)) + if test.threshold != "" { + cmd.SetArgs([]string{"dkg", test.enr, test.feeRecipient, test.withdrawal, test.outputDir, test.threshold}) + } else { + cmd.SetArgs([]string{"dkg", test.enr, test.feeRecipient, test.withdrawal, test.outputDir}) + } + + err := cmd.Execute() + if test.expectedErr != "" { + require.ErrorContains(t, err, test.expectedErr) + } else { + require.NoError(t, err) + } + + if test.cleanup != nil { + test.cleanup(t) + } + }) + } +} diff --git a/tbls/herumi.go b/tbls/herumi.go index 0be2ef2a9..0c50c1599 100644 --- a/tbls/herumi.go +++ b/tbls/herumi.go @@ -84,6 +84,10 @@ func (Herumi) ThresholdSplitInsecure(t *testing.T, secret PrivateKey, total uint t.Helper() var p bls.SecretKey + if threshold <= 1 { + return nil, errors.New("threshold has to be greater than 1") + } + if err := p.Deserialize(secret[:]); err != nil { return nil, errors.Wrap(err, "cannot unmarshal bytes into Herumi secret key") } @@ -133,6 +137,10 @@ func (Herumi) ThresholdSplitInsecure(t *testing.T, secret PrivateKey, total uint func (Herumi) ThresholdSplit(secret PrivateKey, total uint, threshold uint) (map[int]PrivateKey, error) { var p bls.SecretKey + if threshold <= 1 { + return nil, errors.New("threshold has to be greater than 1") + } + if err := p.Deserialize(secret[:]); err != nil { return nil, errors.Wrap(err, "cannot unmarshal bytes into Herumi secret key") }