Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions server/embed/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package embed

import (
"crypto/tls"
"encoding/json"
"errors"
"flag"
"fmt"
Expand Down Expand Up @@ -771,12 +772,74 @@ func ConfigFromFile(path string) (*Config, error) {
return &cfg.Config, nil
}

// durationFieldKeys lists all JSON keys in Config that correspond to time.Duration fields.
// String values (e.g. "1m", "500ms") are converted to nanosecond integers before unmarshaling,
// because time.Duration does not implement json.Unmarshaler.
var durationFieldKeys = map[string]bool{
"backend-batch-interval": true,
"grpc-keepalive-min-time": true,
"grpc-keepalive-interval": true,
"grpc-keepalive-timeout": true,
"corrupt-check-time": true,
"compact-hash-check-time": true,
"compaction-sleep-interval": true,
"watch-progress-notify-interval": true,
"warning-apply-duration": true,
"warning-unary-request-duration": true,
"downgrade-check-time": true,
}

// preprocessDurationFields converts string duration values (e.g. "1m", "500ms") to
// nanosecond integers so that time.Duration fields unmarshal correctly from JSON/YAML.
func preprocessDurationFields(b []byte) ([]byte, error) {
var raw map[string]json.RawMessage
if err := yaml.Unmarshal(b, &raw); err != nil {
// If parsing as a map fails, return the original bytes and let the
// caller handle the error during normal unmarshaling.
return b, nil
}

modified := false
for key, val := range raw {
if !durationFieldKeys[key] {
continue
}
// Try to unmarshal as a string (e.g. "1m", "10s").
var s string
if err := json.Unmarshal(val, &s); err != nil {
// Not a string; might already be a number, which is fine.
continue
}
d, err := time.ParseDuration(s)
if err != nil {
return nil, fmt.Errorf("invalid duration value for %q: %w", key, err)
}
nsBytes, err := json.Marshal(d.Nanoseconds())
if err != nil {
return nil, err
}
raw[key] = nsBytes
modified = true
}

if !modified {
return b, nil
}

return yaml.Marshal(raw)
}

func (cfg *configYAML) configFromFile(path string) error {
b, err := os.ReadFile(path)
if err != nil {
return err
}

b, err = preprocessDurationFields(b)
if err != nil {
return err
}

defaultInitialCluster := cfg.InitialCluster

err = yaml.Unmarshal(b, cfg)
Expand Down
119 changes: 119 additions & 0 deletions server/embed/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package embed

import (
"crypto/tls"
"encoding/json"
"errors"
"flag"
"fmt"
Expand Down Expand Up @@ -944,3 +945,121 @@ func TestFastLeaseKeepAliveValidate(t *testing.T) {
})
}
}

func TestConfigFileDurationFields(t *testing.T) {
testCases := []struct {
name string
config map[string]any
expectErr bool
check func(t *testing.T, cfg *Config)
}{
{
name: "string duration for watch-progress-notify-interval",
config: map[string]any{
"watch-progress-notify-interval": "1m",
},
check: func(t *testing.T, cfg *Config) {
require.Equal(t, time.Minute, cfg.WatchProgressNotifyInterval)
},
},
{
name: "numeric duration (nanoseconds) for watch-progress-notify-interval",
config: map[string]any{
"watch-progress-notify-interval": float64(time.Minute),
},
check: func(t *testing.T, cfg *Config) {
require.Equal(t, time.Minute, cfg.WatchProgressNotifyInterval)
},
},
{
name: "string durations for multiple fields",
config: map[string]any{
"watch-progress-notify-interval": "30s",
"backend-batch-interval": "500ms",
"grpc-keepalive-min-time": "5s",
"grpc-keepalive-interval": "2h",
"grpc-keepalive-timeout": "20s",
"corrupt-check-time": "4m",
"compact-hash-check-time": "2m",
"compaction-sleep-interval": "100ms",
"warning-apply-duration": "200ms",
"warning-unary-request-duration": "300ms",
"downgrade-check-time": "5s",
},
check: func(t *testing.T, cfg *Config) {
require.Equal(t, 30*time.Second, cfg.WatchProgressNotifyInterval)
require.Equal(t, 500*time.Millisecond, cfg.BackendBatchInterval)
require.Equal(t, 5*time.Second, cfg.GRPCKeepAliveMinTime)
require.Equal(t, 2*time.Hour, cfg.GRPCKeepAliveInterval)
require.Equal(t, 20*time.Second, cfg.GRPCKeepAliveTimeout)
require.Equal(t, 4*time.Minute, cfg.CorruptCheckTime)
require.Equal(t, 2*time.Minute, cfg.CompactHashCheckTime)
require.Equal(t, 100*time.Millisecond, cfg.CompactionSleepInterval)
require.Equal(t, 200*time.Millisecond, cfg.WarningApplyDuration)
require.Equal(t, 300*time.Millisecond, cfg.WarningUnaryRequestDuration)
require.Equal(t, 5*time.Second, cfg.DowngradeCheckTime)
},
},
{
name: "invalid duration string",
config: map[string]any{
"watch-progress-notify-interval": "not-a-duration",
},
expectErr: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
b, err := json.Marshal(tc.config)
require.NoError(t, err)

tmpfile := mustCreateCfgFile(t, b)
defer os.Remove(tmpfile.Name())

cfg, err := ConfigFromFile(tmpfile.Name())
if tc.expectErr {
require.Error(t, err)
return
}
require.NoError(t, err)
tc.check(t, cfg)
})
}
}

func TestPreprocessDurationFields(t *testing.T) {
testCases := []struct {
name string
input string
expectErr bool
}{
{
name: "string duration value",
input: `{"watch-progress-notify-interval": "1m"}`,
},
{
name: "numeric duration value passes through",
input: `{"watch-progress-notify-interval": 60000000000}`,
},
{
name: "non-duration field unchanged",
input: `{"name": "my-etcd"}`,
},
{
name: "invalid duration string",
input: `{"watch-progress-notify-interval": "invalid"}`,
expectErr: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result, err := preprocessDurationFields([]byte(tc.input))
if tc.expectErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.NotEmpty(t, result)
})
}
}