diff --git a/network/dag/interface.go b/network/dag/interface.go index 6f70d8c349..8b96036091 100644 --- a/network/dag/interface.go +++ b/network/dag/interface.go @@ -41,6 +41,7 @@ var ErrPayloadNotFound = errors.New("payload not found") type State interface { core.Diagnosable core.Migratable + core.Configurable // WritePayload writes contents for the specified payload, identified by the given hash. // It also calls observers and therefore requires the transaction. diff --git a/network/dag/mock.go b/network/dag/mock.go index 073c18889e..c244959af3 100644 --- a/network/dag/mock.go +++ b/network/dag/mock.go @@ -57,6 +57,20 @@ func (mr *MockStateMockRecorder) Add(ctx, transactions, payload any) *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockState)(nil).Add), ctx, transactions, payload) } +// Configure mocks base method. +func (m *MockState) Configure(config core.ServerConfig) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Configure", config) + ret0, _ := ret[0].(error) + return ret0 +} + +// Configure indicates an expected call of Configure. +func (mr *MockStateMockRecorder) Configure(config any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Configure", reflect.TypeOf((*MockState)(nil).Configure), config) +} + // CorrectStateDetected mocks base method. func (m *MockState) CorrectStateDetected() { m.ctrl.T.Helper() diff --git a/network/dag/state.go b/network/dag/state.go index 0df79b02f6..81c2edf519 100644 --- a/network/dag/state.go +++ b/network/dag/state.go @@ -234,7 +234,7 @@ func (s *state) updateState(tx stoabs.WriteTx, transaction Transaction) error { } func (s *state) loadState(ctx context.Context) { - if err := s.db.Read(ctx, func(tx stoabs.ReadTx) error { + err := s.db.Read(ctx, func(tx stoabs.ReadTx) error { s.lamportClockHigh.Store(s.graph.getHighestClockValue(tx)) if err := s.xorTree.read(tx); err != nil { return fmt.Errorf("failed to read xorTree: %w", err) @@ -243,10 +243,9 @@ func (s *state) loadState(ctx context.Context) { return fmt.Errorf("failed to read ibltTree: %w", err) } return nil - }); err != nil { - log.Logger(). - WithError(err). - Errorf("Failed to load the XOR and IBLT trees") + }) + if err != nil { + log.Logger().WithError(err).Errorf("Failed to load the XOR and IBLT trees") } log.Logger().Trace("Loaded the XOR and IBLT trees") } @@ -394,6 +393,12 @@ func (s *state) CorrectStateDetected() { s.xorTreeRepair.stateOK() } +func (s *state) Configure(_ core.ServerConfig) error { + // state must be loaded before any migration takes place + s.loadState(context.Background()) + return nil +} + func (s *state) Shutdown() error { if s.transactionCount != nil { prometheus.Unregister(s.transactionCount) @@ -405,8 +410,6 @@ func (s *state) Shutdown() error { } func (s *state) Start() error { - s.loadState(context.Background()) - err := s.db.Read(context.Background(), func(tx stoabs.ReadTx) error { currentTXCount := s.graph.getNumberOfTransactions(tx) s.transactionCount.Add(float64(currentTXCount)) diff --git a/network/network.go b/network/network.go index 9a4feb6b4d..b573a71345 100644 --- a/network/network.go +++ b/network/network.go @@ -176,6 +176,11 @@ func (n *Network) Configure(config core.ServerConfig) error { if n.state, err = dag.NewState(dagStore, dag.NewPrevTransactionsVerifier(), dag.NewTransactionSignatureVerifier(nutsKeyResolver)); err != nil { return fmt.Errorf("failed to configure state: %w", err) } + // load state + err = n.state.Configure(core.ServerConfig{}) + if err != nil { + return err + } n.strictMode = config.Strictmode n.peerID = transport.PeerID(uuid.New().String()) diff --git a/network/network_test.go b/network/network_test.go index 04c1c1ea47..9a41c23c54 100644 --- a/network/network_test.go +++ b/network/network_test.go @@ -357,7 +357,9 @@ func TestNetwork_Configure(t *testing.T) { prov := storage.NewMockProvider(ctrl) ctx.network.storeProvider = prov ctx.network.connectionManager = nil - prov.EXPECT().GetKVStore(gomock.Any(), gomock.Any()) + store := stoabs.NewMockKVStore(ctrl) + store.EXPECT().Read(gomock.Any(), gomock.Any()) + prov.EXPECT().GetKVStore(gomock.Any(), gomock.Any()).Return(store, nil) prov.EXPECT().GetKVStore(gomock.Any(), gomock.Any()).Return(nil, errors.New("failed")) err := ctx.network.Configure(core.TestServerConfig(func(config *core.ServerConfig) {