diff --git a/vms/platformvm/state/diff.go b/vms/platformvm/state/diff.go index aceecc47ad56..24bdabfa96da 100644 --- a/vms/platformvm/state/diff.go +++ b/vms/platformvm/state/diff.go @@ -42,6 +42,8 @@ type diff struct { // Subnet ID --> supply of native asset of the subnet currentSupply map[ids.ID]uint64 + expiryDiff *expiryDiff + currentStakerDiffs diffStakers // map of subnetID -> nodeID -> total accrued delegatee rewards modifiedDelegateeRewards map[ids.ID]map[ids.NodeID]uint64 @@ -79,6 +81,7 @@ func NewDiff( timestamp: parentState.GetTimestamp(), feeState: parentState.GetFeeState(), accruedFees: parentState.GetAccruedFees(), + expiryDiff: newExpiryDiff(), subnetOwners: make(map[ids.ID]fx.Owner), subnetManagers: make(map[ids.ID]chainIDAndAddr), }, nil @@ -146,6 +149,41 @@ func (d *diff) SetCurrentSupply(subnetID ids.ID, currentSupply uint64) { } } +func (d *diff) GetExpiryIterator() (iterator.Iterator[ExpiryEntry], error) { + parentState, ok := d.stateVersions.GetState(d.parentID) + if !ok { + return nil, fmt.Errorf("%w: %s", ErrMissingParentState, d.parentID) + } + + parentIterator, err := parentState.GetExpiryIterator() + if err != nil { + return nil, err + } + + return d.expiryDiff.getExpiryIterator(parentIterator), nil +} + +func (d *diff) HasExpiry(entry ExpiryEntry) (bool, error) { + if has, modified := d.expiryDiff.modified[entry]; modified { + return has, nil + } + + parentState, ok := d.stateVersions.GetState(d.parentID) + if !ok { + return false, fmt.Errorf("%w: %s", ErrMissingParentState, d.parentID) + } + + return parentState.HasExpiry(entry) +} + +func (d *diff) PutExpiry(entry ExpiryEntry) { + d.expiryDiff.PutExpiry(entry) +} + +func (d *diff) DeleteExpiry(entry ExpiryEntry) { + d.expiryDiff.DeleteExpiry(entry) +} + func (d *diff) GetCurrentValidator(subnetID ids.ID, nodeID ids.NodeID) (*Staker, error) { // If the validator was modified in this diff, return the modified // validator. @@ -451,6 +489,13 @@ func (d *diff) Apply(baseState Chain) error { for subnetID, supply := range d.currentSupply { baseState.SetCurrentSupply(subnetID, supply) } + for entry, isAdded := range d.expiryDiff.modified { + if isAdded { + baseState.PutExpiry(entry) + } else { + baseState.DeleteExpiry(entry) + } + } for _, subnetValidatorDiffs := range d.currentStakerDiffs.validatorDiffs { for _, validatorDiff := range subnetValidatorDiffs { switch validatorDiff.validatorStatus { diff --git a/vms/platformvm/state/diff_test.go b/vms/platformvm/state/diff_test.go index c4246b8aaeec..b71599928e23 100644 --- a/vms/platformvm/state/diff_test.go +++ b/vms/platformvm/state/diff_test.go @@ -17,6 +17,7 @@ import ( "github.com/ava-labs/avalanchego/utils/constants" "github.com/ava-labs/avalanchego/utils/iterator" "github.com/ava-labs/avalanchego/utils/iterator/iteratormock" + "github.com/ava-labs/avalanchego/utils/set" "github.com/ava-labs/avalanchego/vms/components/avax" "github.com/ava-labs/avalanchego/vms/components/gas" "github.com/ava-labs/avalanchego/vms/platformvm/fx/fxmock" @@ -112,6 +113,156 @@ func TestDiffCurrentSupply(t *testing.T) { assertChainsEqual(t, state, d) } +func TestDiffExpiry(t *testing.T) { + type op struct { + put bool + entry ExpiryEntry + } + tests := []struct { + name string + initialExpiries []ExpiryEntry + ops []op + }{ + { + name: "empty noop", + }, + { + name: "insert", + ops: []op{ + { + put: true, + entry: ExpiryEntry{Timestamp: 1}, + }, + }, + }, + { + name: "remove", + initialExpiries: []ExpiryEntry{ + {Timestamp: 1}, + }, + ops: []op{ + { + put: false, + entry: ExpiryEntry{Timestamp: 1}, + }, + }, + }, + { + name: "add and immediately remove", + ops: []op{ + { + put: true, + entry: ExpiryEntry{Timestamp: 1}, + }, + { + put: false, + entry: ExpiryEntry{Timestamp: 1}, + }, + }, + }, + { + name: "add + remove + add", + ops: []op{ + { + put: true, + entry: ExpiryEntry{Timestamp: 1}, + }, + { + put: false, + entry: ExpiryEntry{Timestamp: 1}, + }, + { + put: true, + entry: ExpiryEntry{Timestamp: 1}, + }, + }, + }, + { + name: "everything", + initialExpiries: []ExpiryEntry{ + {Timestamp: 1}, + {Timestamp: 2}, + {Timestamp: 3}, + }, + ops: []op{ + { + put: false, + entry: ExpiryEntry{Timestamp: 1}, + }, + { + put: false, + entry: ExpiryEntry{Timestamp: 2}, + }, + { + put: true, + entry: ExpiryEntry{Timestamp: 1}, + }, + }, + }, + } + + for _, test := range tests { + require := require.New(t) + + state := newTestState(t, memdb.New()) + for _, expiry := range test.initialExpiries { + state.PutExpiry(expiry) + } + + d, err := NewDiffOn(state) + require.NoError(err) + + var ( + expectedExpiries = set.Of(test.initialExpiries...) + unexpectedExpiries set.Set[ExpiryEntry] + ) + for _, op := range test.ops { + if op.put { + d.PutExpiry(op.entry) + expectedExpiries.Add(op.entry) + unexpectedExpiries.Remove(op.entry) + } else { + d.DeleteExpiry(op.entry) + expectedExpiries.Remove(op.entry) + unexpectedExpiries.Add(op.entry) + } + } + + // If expectedExpiries is empty, we want expectedExpiriesSlice to be + // nil. + var expectedExpiriesSlice []ExpiryEntry + if expectedExpiries.Len() > 0 { + expectedExpiriesSlice = expectedExpiries.List() + utils.Sort(expectedExpiriesSlice) + } + + verifyChain := func(chain Chain) { + expiryIterator, err := chain.GetExpiryIterator() + require.NoError(err) + require.Equal( + expectedExpiriesSlice, + iterator.ToSlice(expiryIterator), + ) + + for expiry := range expectedExpiries { + has, err := chain.HasExpiry(expiry) + require.NoError(err) + require.True(has) + } + for expiry := range unexpectedExpiries { + has, err := chain.HasExpiry(expiry) + require.NoError(err) + require.False(has) + } + } + + verifyChain(d) + require.NoError(d.Apply(state)) + verifyChain(state) + assertChainsEqual(t, d, state) + } +} + func TestDiffCurrentValidator(t *testing.T) { require := require.New(t) ctrl := gomock.NewController(t) @@ -527,6 +678,16 @@ func assertChainsEqual(t *testing.T, expected, actual Chain) { t.Helper() + expectedExpiryIterator, expectedErr := expected.GetExpiryIterator() + actualExpiryIterator, actualErr := actual.GetExpiryIterator() + require.Equal(expectedErr, actualErr) + if expectedErr == nil { + require.Equal( + iterator.ToSlice(expectedExpiryIterator), + iterator.ToSlice(actualExpiryIterator), + ) + } + expectedCurrentStakerIterator, expectedErr := expected.GetCurrentStakerIterator() actualCurrentStakerIterator, actualErr := actual.GetCurrentStakerIterator() require.Equal(expectedErr, actualErr) diff --git a/vms/platformvm/state/expiry.go b/vms/platformvm/state/expiry.go index d2b392b62f4b..b50439ddf20c 100644 --- a/vms/platformvm/state/expiry.go +++ b/vms/platformvm/state/expiry.go @@ -11,6 +11,8 @@ import ( "github.com/ava-labs/avalanchego/database" "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/utils" + "github.com/ava-labs/avalanchego/utils/iterator" ) // expiryEntry = [timestamp] + [validationID] @@ -20,8 +22,26 @@ var ( errUnexpectedExpiryEntryLength = fmt.Errorf("expected expiry entry length %d", expiryEntryLength) _ btree.LessFunc[ExpiryEntry] = ExpiryEntry.Less + _ utils.Sortable[ExpiryEntry] = ExpiryEntry{} ) +type Expiry interface { + // GetExpiryIterator returns an iterator of all the expiry entries in order + // of lowest to highest timestamp. + GetExpiryIterator() (iterator.Iterator[ExpiryEntry], error) + + // HasExpiry returns true if the database has the specified entry. + HasExpiry(ExpiryEntry) (bool, error) + + // PutExpiry adds the entry to the database. If the entry already exists, it + // is a noop. + PutExpiry(ExpiryEntry) + + // DeleteExpiry removes the entry from the database. If the entry doesn't + // exist, it is a noop. + DeleteExpiry(ExpiryEntry) +} + type ExpiryEntry struct { Timestamp uint64 ValidationID ids.ID @@ -44,14 +64,51 @@ func (e *ExpiryEntry) Unmarshal(data []byte) error { return nil } -// Invariant: Less produces the same ordering as the marshalled bytes. func (e ExpiryEntry) Less(o ExpiryEntry) bool { + return e.Compare(o) == -1 +} + +// Invariant: Compare produces the same ordering as the marshalled bytes. +func (e ExpiryEntry) Compare(o ExpiryEntry) int { switch { case e.Timestamp < o.Timestamp: - return true + return -1 case e.Timestamp > o.Timestamp: - return false + return 1 default: - return e.ValidationID.Compare(o.ValidationID) == -1 + return e.ValidationID.Compare(o.ValidationID) + } +} + +type expiryDiff struct { + modified map[ExpiryEntry]bool // bool represents isAdded + added *btree.BTreeG[ExpiryEntry] +} + +func newExpiryDiff() *expiryDiff { + return &expiryDiff{ + modified: make(map[ExpiryEntry]bool), + added: btree.NewG(defaultTreeDegree, ExpiryEntry.Less), } } + +func (e *expiryDiff) PutExpiry(entry ExpiryEntry) { + e.modified[entry] = true + e.added.ReplaceOrInsert(entry) +} + +func (e *expiryDiff) DeleteExpiry(entry ExpiryEntry) { + e.modified[entry] = false + e.added.Delete(entry) +} + +func (e *expiryDiff) getExpiryIterator(parentIterator iterator.Iterator[ExpiryEntry]) iterator.Iterator[ExpiryEntry] { + return iterator.Merge( + ExpiryEntry.Less, + iterator.Filter(parentIterator, func(entry ExpiryEntry) bool { + _, ok := e.modified[entry] + return ok + }), + iterator.FromTree(e.added), + ) +} diff --git a/vms/platformvm/state/mock_chain.go b/vms/platformvm/state/mock_chain.go index 1891b74099d8..3b380a87a8b8 100644 --- a/vms/platformvm/state/mock_chain.go +++ b/vms/platformvm/state/mock_chain.go @@ -142,6 +142,18 @@ func (mr *MockChainMockRecorder) DeleteCurrentValidator(staker any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCurrentValidator", reflect.TypeOf((*MockChain)(nil).DeleteCurrentValidator), staker) } +// DeleteExpiry mocks base method. +func (m *MockChain) DeleteExpiry(arg0 ExpiryEntry) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DeleteExpiry", arg0) +} + +// DeleteExpiry indicates an expected call of DeleteExpiry. +func (mr *MockChainMockRecorder) DeleteExpiry(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteExpiry", reflect.TypeOf((*MockChain)(nil).DeleteExpiry), arg0) +} + // DeletePendingDelegator mocks base method. func (m *MockChain) DeletePendingDelegator(staker *Staker) { m.ctrl.T.Helper() @@ -267,6 +279,21 @@ func (mr *MockChainMockRecorder) GetDelegateeReward(subnetID, nodeID any) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDelegateeReward", reflect.TypeOf((*MockChain)(nil).GetDelegateeReward), subnetID, nodeID) } +// GetExpiryIterator mocks base method. +func (m *MockChain) GetExpiryIterator() (iterator.Iterator[ExpiryEntry], error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetExpiryIterator") + ret0, _ := ret[0].(iterator.Iterator[ExpiryEntry]) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetExpiryIterator indicates an expected call of GetExpiryIterator. +func (mr *MockChainMockRecorder) GetExpiryIterator() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetExpiryIterator", reflect.TypeOf((*MockChain)(nil).GetExpiryIterator)) +} + // GetFeeState mocks base method. func (m *MockChain) GetFeeState() gas.State { m.ctrl.T.Helper() @@ -417,6 +444,21 @@ func (mr *MockChainMockRecorder) GetUTXO(utxoID any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUTXO", reflect.TypeOf((*MockChain)(nil).GetUTXO), utxoID) } +// HasExpiry mocks base method. +func (m *MockChain) HasExpiry(arg0 ExpiryEntry) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HasExpiry", arg0) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// HasExpiry indicates an expected call of HasExpiry. +func (mr *MockChainMockRecorder) HasExpiry(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasExpiry", reflect.TypeOf((*MockChain)(nil).HasExpiry), arg0) +} + // PutCurrentDelegator mocks base method. func (m *MockChain) PutCurrentDelegator(staker *Staker) { m.ctrl.T.Helper() @@ -443,6 +485,18 @@ func (mr *MockChainMockRecorder) PutCurrentValidator(staker any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PutCurrentValidator", reflect.TypeOf((*MockChain)(nil).PutCurrentValidator), staker) } +// PutExpiry mocks base method. +func (m *MockChain) PutExpiry(arg0 ExpiryEntry) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "PutExpiry", arg0) +} + +// PutExpiry indicates an expected call of PutExpiry. +func (mr *MockChainMockRecorder) PutExpiry(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PutExpiry", reflect.TypeOf((*MockChain)(nil).PutExpiry), arg0) +} + // PutPendingDelegator mocks base method. func (m *MockChain) PutPendingDelegator(staker *Staker) { m.ctrl.T.Helper() diff --git a/vms/platformvm/state/mock_diff.go b/vms/platformvm/state/mock_diff.go index 557c59faf58c..77edfde92aaf 100644 --- a/vms/platformvm/state/mock_diff.go +++ b/vms/platformvm/state/mock_diff.go @@ -156,6 +156,18 @@ func (mr *MockDiffMockRecorder) DeleteCurrentValidator(staker any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCurrentValidator", reflect.TypeOf((*MockDiff)(nil).DeleteCurrentValidator), staker) } +// DeleteExpiry mocks base method. +func (m *MockDiff) DeleteExpiry(arg0 ExpiryEntry) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DeleteExpiry", arg0) +} + +// DeleteExpiry indicates an expected call of DeleteExpiry. +func (mr *MockDiffMockRecorder) DeleteExpiry(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteExpiry", reflect.TypeOf((*MockDiff)(nil).DeleteExpiry), arg0) +} + // DeletePendingDelegator mocks base method. func (m *MockDiff) DeletePendingDelegator(staker *Staker) { m.ctrl.T.Helper() @@ -281,6 +293,21 @@ func (mr *MockDiffMockRecorder) GetDelegateeReward(subnetID, nodeID any) *gomock return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDelegateeReward", reflect.TypeOf((*MockDiff)(nil).GetDelegateeReward), subnetID, nodeID) } +// GetExpiryIterator mocks base method. +func (m *MockDiff) GetExpiryIterator() (iterator.Iterator[ExpiryEntry], error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetExpiryIterator") + ret0, _ := ret[0].(iterator.Iterator[ExpiryEntry]) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetExpiryIterator indicates an expected call of GetExpiryIterator. +func (mr *MockDiffMockRecorder) GetExpiryIterator() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetExpiryIterator", reflect.TypeOf((*MockDiff)(nil).GetExpiryIterator)) +} + // GetFeeState mocks base method. func (m *MockDiff) GetFeeState() gas.State { m.ctrl.T.Helper() @@ -431,6 +458,21 @@ func (mr *MockDiffMockRecorder) GetUTXO(utxoID any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUTXO", reflect.TypeOf((*MockDiff)(nil).GetUTXO), utxoID) } +// HasExpiry mocks base method. +func (m *MockDiff) HasExpiry(arg0 ExpiryEntry) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HasExpiry", arg0) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// HasExpiry indicates an expected call of HasExpiry. +func (mr *MockDiffMockRecorder) HasExpiry(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasExpiry", reflect.TypeOf((*MockDiff)(nil).HasExpiry), arg0) +} + // PutCurrentDelegator mocks base method. func (m *MockDiff) PutCurrentDelegator(staker *Staker) { m.ctrl.T.Helper() @@ -457,6 +499,18 @@ func (mr *MockDiffMockRecorder) PutCurrentValidator(staker any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PutCurrentValidator", reflect.TypeOf((*MockDiff)(nil).PutCurrentValidator), staker) } +// PutExpiry mocks base method. +func (m *MockDiff) PutExpiry(arg0 ExpiryEntry) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "PutExpiry", arg0) +} + +// PutExpiry indicates an expected call of PutExpiry. +func (mr *MockDiffMockRecorder) PutExpiry(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PutExpiry", reflect.TypeOf((*MockDiff)(nil).PutExpiry), arg0) +} + // PutPendingDelegator mocks base method. func (m *MockDiff) PutPendingDelegator(staker *Staker) { m.ctrl.T.Helper() diff --git a/vms/platformvm/state/mock_state.go b/vms/platformvm/state/mock_state.go index b6f0c7bd6b8a..2f8ddaa4bc85 100644 --- a/vms/platformvm/state/mock_state.go +++ b/vms/platformvm/state/mock_state.go @@ -257,6 +257,18 @@ func (mr *MockStateMockRecorder) DeleteCurrentValidator(staker any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCurrentValidator", reflect.TypeOf((*MockState)(nil).DeleteCurrentValidator), staker) } +// DeleteExpiry mocks base method. +func (m *MockState) DeleteExpiry(arg0 ExpiryEntry) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DeleteExpiry", arg0) +} + +// DeleteExpiry indicates an expected call of DeleteExpiry. +func (mr *MockStateMockRecorder) DeleteExpiry(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteExpiry", reflect.TypeOf((*MockState)(nil).DeleteExpiry), arg0) +} + // DeletePendingDelegator mocks base method. func (m *MockState) DeletePendingDelegator(staker *Staker) { m.ctrl.T.Helper() @@ -412,6 +424,21 @@ func (mr *MockStateMockRecorder) GetDelegateeReward(subnetID, nodeID any) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDelegateeReward", reflect.TypeOf((*MockState)(nil).GetDelegateeReward), subnetID, nodeID) } +// GetExpiryIterator mocks base method. +func (m *MockState) GetExpiryIterator() (iterator.Iterator[ExpiryEntry], error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetExpiryIterator") + ret0, _ := ret[0].(iterator.Iterator[ExpiryEntry]) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetExpiryIterator indicates an expected call of GetExpiryIterator. +func (mr *MockStateMockRecorder) GetExpiryIterator() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetExpiryIterator", reflect.TypeOf((*MockState)(nil).GetExpiryIterator)) +} + // GetFeeState mocks base method. func (m *MockState) GetFeeState() gas.State { m.ctrl.T.Helper() @@ -652,6 +679,21 @@ func (mr *MockStateMockRecorder) GetUptime(nodeID, subnetID any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUptime", reflect.TypeOf((*MockState)(nil).GetUptime), nodeID, subnetID) } +// HasExpiry mocks base method. +func (m *MockState) HasExpiry(arg0 ExpiryEntry) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HasExpiry", arg0) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// HasExpiry indicates an expected call of HasExpiry. +func (mr *MockStateMockRecorder) HasExpiry(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasExpiry", reflect.TypeOf((*MockState)(nil).HasExpiry), arg0) +} + // PutCurrentDelegator mocks base method. func (m *MockState) PutCurrentDelegator(staker *Staker) { m.ctrl.T.Helper() @@ -678,6 +720,18 @@ func (mr *MockStateMockRecorder) PutCurrentValidator(staker any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PutCurrentValidator", reflect.TypeOf((*MockState)(nil).PutCurrentValidator), staker) } +// PutExpiry mocks base method. +func (m *MockState) PutExpiry(arg0 ExpiryEntry) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "PutExpiry", arg0) +} + +// PutExpiry indicates an expected call of PutExpiry. +func (mr *MockStateMockRecorder) PutExpiry(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PutExpiry", reflect.TypeOf((*MockState)(nil).PutExpiry), arg0) +} + // PutPendingDelegator mocks base method. func (m *MockState) PutPendingDelegator(staker *Staker) { m.ctrl.T.Helper() diff --git a/vms/platformvm/state/state.go b/vms/platformvm/state/state.go index 849a1bb1f110..50f00e1fe8f7 100644 --- a/vms/platformvm/state/state.go +++ b/vms/platformvm/state/state.go @@ -82,6 +82,7 @@ var ( TransformedSubnetPrefix = []byte("transformedSubnet") SupplyPrefix = []byte("supply") ChainPrefix = []byte("chain") + ExpiryReplayProtectionPrefix = []byte("expiryReplayProtection") SingletonPrefix = []byte("singleton") TimestampKey = []byte("timestamp") @@ -97,6 +98,7 @@ var ( // Chain collects all methods to manage the state of the chain for block // execution. type Chain interface { + Expiry Stakers avax.UTXOAdder avax.UTXOGetter @@ -278,6 +280,8 @@ type stateBlk struct { * | '-. subnetID * | '-. list * | '-- txID -> nil + * |-. expiryReplayProtection + * | '-- timestamp + validationID -> nil * '-. singletons * |-- initializedKey -> nil * |-- blocksReindexedKey -> nil @@ -299,6 +303,10 @@ type state struct { baseDB *versiondb.Database + expiry *btree.BTreeG[ExpiryEntry] + expiryDiff *expiryDiff + expiryDB database.Database + currentStakers *baseStakers pendingStakers *baseStakers @@ -610,6 +618,10 @@ func New( blockCache: blockCache, blockDB: prefixdb.New(BlockPrefix, baseDB), + expiry: btree.NewG(defaultTreeDegree, ExpiryEntry.Less), + expiryDiff: newExpiryDiff(), + expiryDB: prefixdb.New(ExpiryReplayProtectionPrefix, baseDB), + currentStakers: newBaseStakers(), pendingStakers: newBaseStakers(), @@ -684,6 +696,27 @@ func New( return s, nil } +func (s *state) GetExpiryIterator() (iterator.Iterator[ExpiryEntry], error) { + return s.expiryDiff.getExpiryIterator( + iterator.FromTree(s.expiry), + ), nil +} + +func (s *state) HasExpiry(entry ExpiryEntry) (bool, error) { + if has, modified := s.expiryDiff.modified[entry]; modified { + return has, nil + } + return s.expiry.Has(entry), nil +} + +func (s *state) PutExpiry(entry ExpiryEntry) { + s.expiryDiff.PutExpiry(entry) +} + +func (s *state) DeleteExpiry(entry ExpiryEntry) { + s.expiryDiff.DeleteExpiry(entry) +} + func (s *state) GetCurrentValidator(subnetID ids.ID, nodeID ids.NodeID) (*Staker, error) { return s.currentStakers.GetValidator(subnetID, nodeID) } @@ -1336,6 +1369,7 @@ func (s *state) syncGenesis(genesisBlk block.Block, genesis *genesis.Genesis) er func (s *state) load() error { return errors.Join( s.loadMetadata(), + s.loadExpiry(), s.loadCurrentValidators(), s.loadPendingValidators(), s.initValidatorSets(), @@ -1407,6 +1441,23 @@ func (s *state) loadMetadata() error { return nil } +func (s *state) loadExpiry() error { + it := s.expiryDB.NewIterator() + defer it.Release() + + for it.Next() { + key := it.Key() + + var entry ExpiryEntry + if err := entry.Unmarshal(key); err != nil { + return fmt.Errorf("failed to unmarshal ExpiryEntry during load: %w", err) + } + s.expiry.ReplaceOrInsert(entry) + } + + return nil +} + func (s *state) loadCurrentValidators() error { s.currentStakers = newBaseStakers() @@ -1705,6 +1756,7 @@ func (s *state) write(updateValidators bool, height uint64) error { return errors.Join( s.writeBlocks(), + s.writeExpiry(), s.writeCurrentStakers(updateValidators, height, codecVersion), s.writePendingStakers(), s.WriteValidatorMetadata(s.currentValidatorList, s.currentSubnetValidatorList, codecVersion), // Must be called after writeCurrentStakers @@ -1723,6 +1775,7 @@ func (s *state) write(updateValidators bool, height uint64) error { func (s *state) Close() error { return errors.Join( + s.expiryDB.Close(), s.pendingSubnetValidatorBaseDB.Close(), s.pendingSubnetDelegatorBaseDB.Close(), s.pendingDelegatorBaseDB.Close(), @@ -1929,6 +1982,28 @@ func (s *state) GetBlockIDAtHeight(height uint64) (ids.ID, error) { return blkID, nil } +func (s *state) writeExpiry() error { + for entry, isAdded := range s.expiryDiff.modified { + var ( + key = entry.Marshal() + err error + ) + if isAdded { + s.expiry.ReplaceOrInsert(entry) + err = s.expiryDB.Put(key, nil) + } else { + s.expiry.Delete(entry) + err = s.expiryDB.Delete(key) + } + if err != nil { + return err + } + } + + s.expiryDiff = newExpiryDiff() + return nil +} + func (s *state) writeCurrentStakers(updateValidators bool, height uint64, codecVersion uint16) error { for subnetID, validatorDiffs := range s.currentStakers.validatorDiffs { delete(s.currentStakers.validatorDiffs, subnetID) diff --git a/vms/platformvm/state/state_test.go b/vms/platformvm/state/state_test.go index 515794364d71..a0dcd0cb7dcc 100644 --- a/vms/platformvm/state/state_test.go +++ b/vms/platformvm/state/state_test.go @@ -1577,3 +1577,35 @@ func TestGetFeeStateErrors(t *testing.T) { }) } } + +// Verify that committing the state writes the expiry changes to the database +// and that loading the state fetches the expiry from the database. +func TestStateExpiryCommitAndLoad(t *testing.T) { + require := require.New(t) + + db := memdb.New() + s := newTestState(t, db) + + // Populate an entry. + expiry := ExpiryEntry{ + Timestamp: 1, + } + s.PutExpiry(expiry) + require.NoError(s.Commit()) + + // Verify that the entry was written and loaded correctly. + s = newTestState(t, db) + has, err := s.HasExpiry(expiry) + require.NoError(err) + require.True(has) + + // Delete an entry. + s.DeleteExpiry(expiry) + require.NoError(s.Commit()) + + // Verify that the entry was deleted correctly. + s = newTestState(t, db) + has, err = s.HasExpiry(expiry) + require.NoError(err) + require.False(has) +}