From c939dc1ee9976fc144d46720a4639b0743c33f8e Mon Sep 17 00:00:00 2001 From: Wout Slakhorst Date: Mon, 13 May 2024 15:31:11 +0200 Subject: [PATCH] Change tag in discovery service to simple lamport clock value (int) (#3098) --- discovery/api/server/api.go | 33 +- discovery/api/server/api_test.go | 45 ++- discovery/api/server/client/http.go | 20 +- discovery/api/server/client/http_test.go | 49 ++- discovery/api/server/client/interface.go | 8 +- discovery/api/server/client/mock.go | 12 +- discovery/api/server/client/types.go | 6 +- discovery/api/server/generated.go | 12 +- discovery/client.go | 8 +- discovery/client_test.go | 25 +- discovery/interface.go | 61 +--- discovery/interface_test.go | 61 ---- discovery/mock.go | 20 +- discovery/module.go | 89 ++++-- discovery/module_test.go | 302 +++++++++++------- discovery/store.go | 143 +++------ discovery/store_test.go | 102 +++--- docs/_static/discovery/server.yaml | 24 +- .../sql_migrations/002_discoveryservice.sql | 15 +- 19 files changed, 509 insertions(+), 526 deletions(-) delete mode 100644 discovery/interface_test.go diff --git a/discovery/api/server/api.go b/discovery/api/server/api.go index 09531ad6e3..2b9b66bf99 100644 --- a/discovery/api/server/api.go +++ b/discovery/api/server/api.go @@ -41,8 +41,6 @@ type Wrapper struct { func (w *Wrapper) ResolveStatusCode(err error) int { switch { - case errors.Is(err, discovery.ErrServerModeDisabled): - return http.StatusBadRequest case errors.Is(err, discovery.ErrInvalidPresentation): return http.StatusBadRequest default: @@ -63,27 +61,36 @@ func (w *Wrapper) Routes(router core.EchoRouter) { })) } -func (w *Wrapper) GetPresentations(_ context.Context, request GetPresentationsRequestObject) (GetPresentationsResponseObject, error) { - var tag *discovery.Tag - if request.Params.Tag != nil { - // *string to *Tag - tag = new(discovery.Tag) - *tag = discovery.Tag(*request.Params.Tag) +func (w *Wrapper) GetPresentations(ctx context.Context, request GetPresentationsRequestObject) (GetPresentationsResponseObject, error) { + var timestamp int + if request.Params.Timestamp != nil { + timestamp = *request.Params.Timestamp } - presentations, newTag, err := w.Server.Get(request.ServiceID, tag) + + presentations, newTimestamp, err := w.Server.Get(contextWithForwardedHost(ctx), request.ServiceID, timestamp) if err != nil { return nil, err } return GetPresentations200JSONResponse{ - Entries: presentations, - Tag: string(*newTag), + Entries: presentations, + Timestamp: newTimestamp, }, nil } -func (w *Wrapper) RegisterPresentation(_ context.Context, request RegisterPresentationRequestObject) (RegisterPresentationResponseObject, error) { - err := w.Server.Register(request.ServiceID, *request.Body) +func (w *Wrapper) RegisterPresentation(ctx context.Context, request RegisterPresentationRequestObject) (RegisterPresentationResponseObject, error) { + err := w.Server.Register(contextWithForwardedHost(ctx), request.ServiceID, *request.Body) if err != nil { return nil, err } return RegisterPresentation201Response{}, nil } + +func contextWithForwardedHost(ctx context.Context) context.Context { + // cast context to echo.Context + echoCtx := ctx.Value("echo.Context") + if echoCtx != nil { + // forward X-Forwarded-Host header via context + ctx = context.WithValue(ctx, discovery.XForwardedHostContextKey{}, echoCtx.(echo.Context).Request().Header.Get("X-Forwarded-Host")) + } + return ctx +} diff --git a/discovery/api/server/api_test.go b/discovery/api/server/api_test.go index 809258d414..5a5cc8cf3b 100644 --- a/discovery/api/server/api_test.go +++ b/discovery/api/server/api_test.go @@ -19,8 +19,8 @@ package server import ( + "context" "errors" - "github.com/nuts-foundation/go-did/did" "github.com/nuts-foundation/go-did/vc" "github.com/nuts-foundation/nuts-node/discovery" "github.com/stretchr/testify/assert" @@ -32,58 +32,56 @@ import ( const serviceID = "wonderland" -var subjectDID = did.MustParseDID("did:web:example.com") - func TestWrapper_GetPresentations(t *testing.T) { - t.Run("no tag", func(t *testing.T) { - latestTag := discovery.Tag("latest") + lastTimestamp := 1 + presentations := map[string]vc.VerifiablePresentation{} + ctx := context.Background() + t.Run("no timestamp", func(t *testing.T) { test := newMockContext(t) - presentations := []vc.VerifiablePresentation{} - test.server.EXPECT().Get(serviceID, nil).Return(presentations, &latestTag, nil) + test.server.EXPECT().Get(gomock.Any(), serviceID, 0).Return(presentations, lastTimestamp, nil) - response, err := test.wrapper.GetPresentations(nil, GetPresentationsRequestObject{ServiceID: serviceID}) + response, err := test.wrapper.GetPresentations(ctx, GetPresentationsRequestObject{ServiceID: serviceID}) require.NoError(t, err) require.IsType(t, GetPresentations200JSONResponse{}, response) - assert.Equal(t, latestTag, discovery.Tag(response.(GetPresentations200JSONResponse).Tag)) + assert.Equal(t, lastTimestamp, response.(GetPresentations200JSONResponse).Timestamp) assert.Equal(t, presentations, response.(GetPresentations200JSONResponse).Entries) }) - t.Run("with tag", func(t *testing.T) { - givenTag := discovery.Tag("given") - latestTag := discovery.Tag("latest") + t.Run("with timestamp", func(t *testing.T) { + givenTimestamp := 1 test := newMockContext(t) - presentations := []vc.VerifiablePresentation{} - test.server.EXPECT().Get(serviceID, &givenTag).Return(presentations, &latestTag, nil) + test.server.EXPECT().Get(gomock.Any(), serviceID, 1).Return(presentations, lastTimestamp, nil) - response, err := test.wrapper.GetPresentations(nil, GetPresentationsRequestObject{ + response, err := test.wrapper.GetPresentations(ctx, GetPresentationsRequestObject{ ServiceID: serviceID, Params: GetPresentationsParams{ - Tag: (*string)(&givenTag), + Timestamp: &givenTimestamp, }, }) require.NoError(t, err) require.IsType(t, GetPresentations200JSONResponse{}, response) - assert.Equal(t, latestTag, discovery.Tag(response.(GetPresentations200JSONResponse).Tag)) + assert.Equal(t, lastTimestamp, response.(GetPresentations200JSONResponse).Timestamp) assert.Equal(t, presentations, response.(GetPresentations200JSONResponse).Entries) }) t.Run("error", func(t *testing.T) { test := newMockContext(t) - test.server.EXPECT().Get(serviceID, nil).Return(nil, nil, errors.New("foo")) + test.server.EXPECT().Get(gomock.Any(), serviceID, 0).Return(nil, 0, errors.New("foo")) - _, err := test.wrapper.GetPresentations(nil, GetPresentationsRequestObject{ServiceID: serviceID}) + _, err := test.wrapper.GetPresentations(ctx, GetPresentationsRequestObject{ServiceID: serviceID}) assert.Error(t, err) }) } func TestWrapper_RegisterPresentation(t *testing.T) { + ctx := context.Background() t.Run("ok", func(t *testing.T) { test := newMockContext(t) presentation := vc.VerifiablePresentation{} - test.server.EXPECT().Register(serviceID, presentation).Return(nil) + test.server.EXPECT().Register(gomock.Any(), serviceID, presentation).Return(nil) - response, err := test.wrapper.RegisterPresentation(nil, RegisterPresentationRequestObject{ + response, err := test.wrapper.RegisterPresentation(ctx, RegisterPresentationRequestObject{ ServiceID: serviceID, Body: &presentation, }) @@ -94,9 +92,9 @@ func TestWrapper_RegisterPresentation(t *testing.T) { t.Run("error", func(t *testing.T) { test := newMockContext(t) presentation := vc.VerifiablePresentation{} - test.server.EXPECT().Register(serviceID, presentation).Return(discovery.ErrInvalidPresentation) + test.server.EXPECT().Register(gomock.Any(), serviceID, presentation).Return(discovery.ErrInvalidPresentation) - _, err := test.wrapper.RegisterPresentation(nil, RegisterPresentationRequestObject{ + _, err := test.wrapper.RegisterPresentation(ctx, RegisterPresentationRequestObject{ ServiceID: serviceID, Body: &presentation, }) @@ -107,7 +105,6 @@ func TestWrapper_RegisterPresentation(t *testing.T) { func TestWrapper_ResolveStatusCode(t *testing.T) { expected := map[error]int{ - discovery.ErrServerModeDisabled: http.StatusBadRequest, discovery.ErrInvalidPresentation: http.StatusBadRequest, errors.New("foo"): http.StatusInternalServerError, } diff --git a/discovery/api/server/client/http.go b/discovery/api/server/client/http.go index 1c3be3dd79..0fb18e05fb 100644 --- a/discovery/api/server/client/http.go +++ b/discovery/api/server/client/http.go @@ -54,6 +54,7 @@ func (h DefaultHTTPClient) Register(ctx context.Context, serviceEndpointURL stri return err } httpRequest.Header.Set("Content-Type", "application/json") + httpRequest.Header.Set("X-Forwarded-Host", httpRequest.Host) // prevent cycles httpResponse, err := h.client.Do(httpRequest) if err != nil { return fmt.Errorf("failed to invoke remote Discovery Service (url=%s): %w", serviceEndpointURL, err) @@ -65,29 +66,28 @@ func (h DefaultHTTPClient) Register(ctx context.Context, serviceEndpointURL stri return nil } -func (h DefaultHTTPClient) Get(ctx context.Context, serviceEndpointURL string, tag string) ([]vc.VerifiablePresentation, string, error) { +func (h DefaultHTTPClient) Get(ctx context.Context, serviceEndpointURL string, timestamp int) (map[string]vc.VerifiablePresentation, int, error) { httpRequest, err := http.NewRequestWithContext(ctx, http.MethodGet, serviceEndpointURL, nil) - if tag != "" { - httpRequest.URL.RawQuery = url.Values{"tag": []string{tag}}.Encode() - } + httpRequest.URL.RawQuery = url.Values{"timestamp": []string{fmt.Sprintf("%d", timestamp)}}.Encode() if err != nil { - return nil, "", err + return nil, 0, err } + httpRequest.Header.Set("X-Forwarded-Host", httpRequest.Host) // prevent cycles httpResponse, err := h.client.Do(httpRequest) if err != nil { - return nil, "", fmt.Errorf("failed to invoke remote Discovery Service (url=%s): %w", serviceEndpointURL, err) + return nil, 0, fmt.Errorf("failed to invoke remote Discovery Service (url=%s): %w", serviceEndpointURL, err) } defer httpResponse.Body.Close() if err := core.TestResponseCode(200, httpResponse); err != nil { - return nil, "", fmt.Errorf("non-OK response from remote Discovery Service (url=%s): %w", serviceEndpointURL, err) + return nil, 0, fmt.Errorf("non-OK response from remote Discovery Service (url=%s): %w", serviceEndpointURL, err) } responseData, err := io.ReadAll(httpResponse.Body) if err != nil { - return nil, "", fmt.Errorf("failed to read response from remote Discovery Service (url=%s): %w", serviceEndpointURL, err) + return nil, 0, fmt.Errorf("failed to read response from remote Discovery Service (url=%s): %w", serviceEndpointURL, err) } var result PresentationsResponse if err := json.Unmarshal(responseData, &result); err != nil { - return nil, "", fmt.Errorf("failed to unmarshal response from remote Discovery Service (url=%s): %w", serviceEndpointURL, err) + return nil, 0, fmt.Errorf("failed to unmarshal response from remote Discovery Service (url=%s): %w", serviceEndpointURL, err) } - return result.Entries, result.Tag, nil + return result.Entries, result.Timestamp, nil } diff --git a/discovery/api/server/client/http_test.go b/discovery/api/server/client/http_test.go index 108e90a5db..9387dc5fd6 100644 --- a/discovery/api/server/client/http_test.go +++ b/discovery/api/server/client/http_test.go @@ -24,8 +24,10 @@ import ( "github.com/nuts-foundation/go-did/vc" testHTTP "github.com/nuts-foundation/nuts-node/test/http" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "net/http" "net/http/httptest" + "strings" "testing" "time" ) @@ -61,46 +63,61 @@ func TestHTTPInvoker_Get(t *testing.T) { vp := vc.VerifiablePresentation{ Context: []ssi.URI{ssi.MustParseURI("https://www.w3.org/2018/credentials/v1")}, } - const clientTag = "client-tag" - const serverTag = "server-tag" - t.Run("no tag from client", func(t *testing.T) { + + t.Run("no timestamp from client", func(t *testing.T) { handler := &testHTTP.Handler{StatusCode: http.StatusOK} handler.ResponseData = map[string]interface{}{ - "entries": []interface{}{vp}, - "tag": serverTag, + "entries": map[string]interface{}{"1": vp}, + "timestamp": 1, } server := httptest.NewServer(handler) client := New(false, time.Minute, server.TLS) - presentations, tag, err := client.Get(context.Background(), server.URL, "") + presentations, timestamp, err := client.Get(context.Background(), server.URL, 0) assert.NoError(t, err) assert.Len(t, presentations, 1) - assert.Empty(t, handler.RequestQuery.Get("tag")) - assert.Equal(t, serverTag, tag) + assert.Equal(t, "0", handler.RequestQuery.Get("timestamp")) + assert.Equal(t, 1, timestamp) }) - t.Run("tag provided by client", func(t *testing.T) { + t.Run("timestamp provided by client", func(t *testing.T) { handler := &testHTTP.Handler{StatusCode: http.StatusOK} handler.ResponseData = map[string]interface{}{ - "entries": []interface{}{vp}, - "tag": serverTag, + "entries": map[string]interface{}{"1": vp}, + "timestamp": 1, } server := httptest.NewServer(handler) client := New(false, time.Minute, server.TLS) - presentations, tag, err := client.Get(context.Background(), server.URL, clientTag) + presentations, timestamp, err := client.Get(context.Background(), server.URL, 1) assert.NoError(t, err) assert.Len(t, presentations, 1) - assert.Equal(t, clientTag, handler.RequestQuery.Get("tag")) - assert.Equal(t, serverTag, tag) + assert.Equal(t, "1", handler.RequestQuery.Get("timestamp")) + assert.Equal(t, 1, timestamp) + }) + t.Run("check X-Forwarded-Host header", func(t *testing.T) { + // custom handler to check the X-Forwarded-Host header + var capturedRequest *http.Request + handler := func(writer http.ResponseWriter, request *http.Request) { + capturedRequest = request + writer.WriteHeader(http.StatusOK) + writer.Write([]byte("{}")) + } + server := httptest.NewServer(http.HandlerFunc(handler)) + client := New(false, time.Minute, server.TLS) + + _, _, err := client.Get(context.Background(), server.URL, 0) + + require.NoError(t, err) + assert.True(t, strings.HasPrefix(capturedRequest.Header.Get("X-Forwarded-Host"), "127.0.0.1")) }) t.Run("server returns invalid status code", func(t *testing.T) { handler := &testHTTP.Handler{StatusCode: http.StatusInternalServerError} server := httptest.NewServer(handler) client := New(false, time.Minute, server.TLS) - _, _, err := client.Get(context.Background(), server.URL, "") + _, _, err := client.Get(context.Background(), server.URL, 0) assert.ErrorContains(t, err, "non-OK response from remote Discovery Service") }) @@ -110,7 +127,7 @@ func TestHTTPInvoker_Get(t *testing.T) { server := httptest.NewServer(handler) client := New(false, time.Minute, server.TLS) - _, _, err := client.Get(context.Background(), server.URL, "") + _, _, err := client.Get(context.Background(), server.URL, 0) assert.ErrorContains(t, err, "failed to unmarshal response from remote Discovery Service") }) diff --git a/discovery/api/server/client/interface.go b/discovery/api/server/client/interface.go index e307e1bde8..24087f718f 100644 --- a/discovery/api/server/client/interface.go +++ b/discovery/api/server/client/interface.go @@ -28,8 +28,8 @@ type HTTPClient interface { // Register registers a Verifiable Presentation on the remote Discovery Service. Register(ctx context.Context, serviceEndpointURL string, presentation vc.VerifiablePresentation) error - // Get retrieves Verifiable Presentations from the remote Discovery Service, that were added since the given tag. - // If the call succeeds it returns the Verifiable Presentations and the tag that was returned by the server. - // If tag is empty, all Verifiable Presentations are retrieved. - Get(ctx context.Context, serviceEndpointURL string, tag string) ([]vc.VerifiablePresentation, string, error) + // Get retrieves Verifiable Presentations from the remote Discovery Service, that were added since the given timestamp. + // If the call succeeds it returns the Verifiable Presentations and the timestamp that was returned by the server. + // If the given timestamp is 0, all Verifiable Presentations are retrieved. + Get(ctx context.Context, serviceEndpointURL string, timestamp int) (map[string]vc.VerifiablePresentation, int, error) } diff --git a/discovery/api/server/client/mock.go b/discovery/api/server/client/mock.go index f069999f22..2fe595a282 100644 --- a/discovery/api/server/client/mock.go +++ b/discovery/api/server/client/mock.go @@ -41,19 +41,19 @@ func (m *MockHTTPClient) EXPECT() *MockHTTPClientMockRecorder { } // Get mocks base method. -func (m *MockHTTPClient) Get(ctx context.Context, serviceEndpointURL, tag string) ([]vc.VerifiablePresentation, string, error) { +func (m *MockHTTPClient) Get(ctx context.Context, serviceEndpointURL string, timestamp int) (map[string]vc.VerifiablePresentation, int, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Get", ctx, serviceEndpointURL, tag) - ret0, _ := ret[0].([]vc.VerifiablePresentation) - ret1, _ := ret[1].(string) + ret := m.ctrl.Call(m, "Get", ctx, serviceEndpointURL, timestamp) + ret0, _ := ret[0].(map[string]vc.VerifiablePresentation) + ret1, _ := ret[1].(int) ret2, _ := ret[2].(error) return ret0, ret1, ret2 } // Get indicates an expected call of Get. -func (mr *MockHTTPClientMockRecorder) Get(ctx, serviceEndpointURL, tag any) *gomock.Call { +func (mr *MockHTTPClientMockRecorder) Get(ctx, serviceEndpointURL, timestamp any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockHTTPClient)(nil).Get), ctx, serviceEndpointURL, tag) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockHTTPClient)(nil).Get), ctx, serviceEndpointURL, timestamp) } // Register mocks base method. diff --git a/discovery/api/server/client/types.go b/discovery/api/server/client/types.go index caec4009ad..89c17609ee 100644 --- a/discovery/api/server/client/types.go +++ b/discovery/api/server/client/types.go @@ -22,6 +22,8 @@ import "github.com/nuts-foundation/go-did/vc" // PresentationsResponse is the response for the GetPresentations endpoint. type PresentationsResponse struct { - Entries []vc.VerifiablePresentation `json:"entries"` - Tag string `json:"tag"` + // Entries contains mappings from timestamp (as string) to a VerifiablePresentation. + Entries map[string]vc.VerifiablePresentation `json:"entries"` + // Timestamp is the timestamp of the latest entry. It's not a unix timestamp but a Lamport Clock. + Timestamp int `json:"timestamp"` } diff --git a/discovery/api/server/generated.go b/discovery/api/server/generated.go index a057b53ab6..a0beae5243 100644 --- a/discovery/api/server/generated.go +++ b/discovery/api/server/generated.go @@ -35,7 +35,7 @@ type SearchResult struct { // GetPresentationsParams defines parameters for GetPresentations. type GetPresentationsParams struct { - Tag *string `form:"tag,omitempty" json:"tag,omitempty"` + Timestamp *int `form:"timestamp,omitempty" json:"timestamp,omitempty"` } // RegisterPresentationJSONRequestBody defines body for RegisterPresentation for application/json ContentType. @@ -188,9 +188,9 @@ func NewGetPresentationsRequest(server string, serviceID string, params *GetPres if params != nil { queryValues := queryURL.Query() - if params.Tag != nil { + if params.Timestamp != nil { - if queryFrag, err := runtime.StyleParamWithLocation("form", true, "tag", runtime.ParamLocationQuery, *params.Tag); err != nil { + if queryFrag, err := runtime.StyleParamWithLocation("form", true, "timestamp", runtime.ParamLocationQuery, *params.Timestamp); err != nil { return nil, err } else if parsed, err := url.ParseQuery(queryFrag); err != nil { return nil, err @@ -534,11 +534,11 @@ func (w *ServerInterfaceWrapper) GetPresentations(ctx echo.Context) error { // Parameter object where we will unmarshal all parameters from the context var params GetPresentationsParams - // ------------- Optional query parameter "tag" ------------- + // ------------- Optional query parameter "timestamp" ------------- - err = runtime.BindQueryParameter("form", true, false, "tag", ctx.QueryParams(), ¶ms.Tag) + err = runtime.BindQueryParameter("form", true, false, "timestamp", ctx.QueryParams(), ¶ms.Timestamp) if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter tag: %s", err)) + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter timestamp: %s", err)) } // Invoke the callback with all the unmarshaled arguments diff --git a/discovery/client.go b/discovery/client.go index 47d910ef86..a7ef01274b 100644 --- a/discovery/client.go +++ b/discovery/client.go @@ -205,14 +205,14 @@ func (u *clientUpdater) update(ctx context.Context) error { } func (u *clientUpdater) updateService(ctx context.Context, service ServiceDefinition) error { - currentTag, err := u.store.getTag(service.ID) + currentTimestamp, err := u.store.getTimestamp(service.ID) if err != nil { return err } log.Logger(). WithField("discoveryService", service.ID). - Tracef("Checking for new Verifiable Presentations from Discovery Service (tag: %s)", currentTag) - presentations, tag, err := u.client.Get(ctx, service.Endpoint, string(currentTag)) + Tracef("Checking for new Verifiable Presentations from Discovery Service (timestamp: %d)", currentTimestamp) + presentations, serverTimestamp, err := u.client.Get(ctx, service.Endpoint, currentTimestamp) if err != nil { return fmt.Errorf("failed to get presentations from discovery service (id=%s): %w", service.ID, err) } @@ -221,7 +221,7 @@ func (u *clientUpdater) updateService(ctx context.Context, service ServiceDefini log.Logger().WithError(err).Warnf("Presentation verification failed, not adding it (service=%s, id=%s)", service.ID, presentation.ID) continue } - if err := u.store.add(service.ID, presentation, Tag(tag)); err != nil { + if err := u.store.add(service.ID, presentation, serverTimestamp); err != nil { return fmt.Errorf("failed to store presentation (service=%s, id=%s): %w", service.ID, presentation.ID, err) } log.Logger(). diff --git a/discovery/client_test.go b/discovery/client_test.go index 69634d037b..24df7dc227 100644 --- a/discovery/client_test.go +++ b/discovery/client_test.go @@ -151,7 +151,7 @@ func Test_scheduledRegistrationManager_deregister(t *testing.T) { mockVCR.EXPECT().Wallet().Return(wallet).AnyTimes() store := setupStore(t, storageEngine.GetSQLDatabase()) manager := newRegistrationManager(testDefinitions(), store, invoker, mockVCR) - require.NoError(t, store.add(testServiceID, vpAlice, "taggy")) + require.NoError(t, store.add(testServiceID, vpAlice, 1)) err := manager.deactivate(audit.TestContext(), testServiceID, aliceDID) @@ -167,7 +167,7 @@ func Test_scheduledRegistrationManager_deregister(t *testing.T) { mockVCR.EXPECT().Wallet().Return(wallet).AnyTimes() store := setupStore(t, storageEngine.GetSQLDatabase()) manager := newRegistrationManager(testDefinitions(), store, invoker, mockVCR) - require.NoError(t, store.add(testServiceID, vpAlice, "taggy")) + require.NoError(t, store.add(testServiceID, vpAlice, 1)) err := manager.deactivate(audit.TestContext(), testServiceID, aliceDID) @@ -221,10 +221,9 @@ func Test_scheduledRegistrationManager_refresh(t *testing.T) { func Test_clientUpdater_updateService(t *testing.T) { storageEngine := storage.NewTestStorageEngine(t) require.NoError(t, storageEngine.Start()) - store, err := newSQLStore(storageEngine.GetSQLDatabase(), testDefinitions(), nil) + store, err := newSQLStore(storageEngine.GetSQLDatabase(), testDefinitions()) require.NoError(t, err) ctx := context.Background() - newTag := "test" serviceDefinition := testDefinitions()[testServiceID] t.Run("no updates", func(t *testing.T) { @@ -233,7 +232,7 @@ func Test_clientUpdater_updateService(t *testing.T) { httpClient := client.NewMockHTTPClient(ctrl) updater := newClientUpdater(testDefinitions(), store, alwaysOkVerifier, httpClient) - httpClient.EXPECT().Get(ctx, testDefinitions()[testServiceID].Endpoint, "").Return([]vc.VerifiablePresentation{}, newTag, nil) + httpClient.EXPECT().Get(ctx, testDefinitions()[testServiceID].Endpoint, 0).Return(map[string]vc.VerifiablePresentation{}, 0, nil) err := updater.updateService(ctx, testDefinitions()[testServiceID]) @@ -245,7 +244,7 @@ func Test_clientUpdater_updateService(t *testing.T) { httpClient := client.NewMockHTTPClient(ctrl) updater := newClientUpdater(testDefinitions(), store, alwaysOkVerifier, httpClient) - httpClient.EXPECT().Get(ctx, serviceDefinition.Endpoint, "").Return([]vc.VerifiablePresentation{vpAlice}, newTag, nil) + httpClient.EXPECT().Get(ctx, serviceDefinition.Endpoint, 0).Return(map[string]vc.VerifiablePresentation{"1": vpAlice}, 1, nil) err := updater.updateService(ctx, testDefinitions()[testServiceID]) @@ -262,7 +261,7 @@ func Test_clientUpdater_updateService(t *testing.T) { return nil }, httpClient) - httpClient.EXPECT().Get(ctx, serviceDefinition.Endpoint, "").Return([]vc.VerifiablePresentation{vpAlice, vpBob}, newTag, nil) + httpClient.EXPECT().Get(ctx, serviceDefinition.Endpoint, 0).Return(map[string]vc.VerifiablePresentation{"1": vpAlice, "2": vpBob}, 2, nil) err := updater.updateService(ctx, testDefinitions()[testServiceID]) @@ -275,15 +274,15 @@ func Test_clientUpdater_updateService(t *testing.T) { require.NoError(t, err) require.False(t, exists) }) - t.Run("pass tag", func(t *testing.T) { + t.Run("pass timestamp", func(t *testing.T) { resetStore(t, storageEngine.GetSQLDatabase()) ctrl := gomock.NewController(t) httpClient := client.NewMockHTTPClient(ctrl) - _, err := store.updateTag(store.db, testServiceID, "test") + err := store.setTimestamp(store.db, testServiceID, 1) require.NoError(t, err) updater := newClientUpdater(testDefinitions(), store, alwaysOkVerifier, httpClient) - httpClient.EXPECT().Get(ctx, serviceDefinition.Endpoint, "test").Return([]vc.VerifiablePresentation{vpAlice}, newTag, nil) + httpClient.EXPECT().Get(ctx, serviceDefinition.Endpoint, 1).Return(map[string]vc.VerifiablePresentation{"1": vpAlice}, 1, nil) err = updater.updateService(ctx, testDefinitions()[testServiceID]) @@ -298,8 +297,8 @@ func Test_clientUpdater_update(t *testing.T) { store := setupStore(t, storageEngine.GetSQLDatabase()) ctrl := gomock.NewController(t) httpClient := client.NewMockHTTPClient(ctrl) - httpClient.EXPECT().Get(gomock.Any(), "http://example.com/usecase", gomock.Any()).Return([]vc.VerifiablePresentation{}, "test", nil) - httpClient.EXPECT().Get(gomock.Any(), "http://example.com/other", gomock.Any()).Return(nil, "", errors.New("test")) + httpClient.EXPECT().Get(gomock.Any(), "http://example.com/usecase", gomock.Any()).Return(map[string]vc.VerifiablePresentation{}, 0, nil) + httpClient.EXPECT().Get(gomock.Any(), "http://example.com/other", gomock.Any()).Return(nil, 0, errors.New("test")) updater := newClientUpdater(testDefinitions(), store, alwaysOkVerifier, httpClient) err := updater.update(context.Background()) @@ -312,7 +311,7 @@ func Test_clientUpdater_update(t *testing.T) { store := setupStore(t, storageEngine.GetSQLDatabase()) ctrl := gomock.NewController(t) httpClient := client.NewMockHTTPClient(ctrl) - httpClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return([]vc.VerifiablePresentation{}, "test", nil).MinTimes(2) + httpClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return(map[string]vc.VerifiablePresentation{}, 0, nil).MinTimes(2) updater := newClientUpdater(testDefinitions(), store, alwaysOkVerifier, httpClient) err := updater.update(context.Background()) diff --git a/discovery/interface.go b/discovery/interface.go index e9fd655633..85322e1466 100644 --- a/discovery/interface.go +++ b/discovery/interface.go @@ -23,56 +23,8 @@ import ( "errors" "github.com/nuts-foundation/go-did/did" "github.com/nuts-foundation/go-did/vc" - "strconv" - "strings" ) -// Tag is value that references a point in the list. -// It is used by clients to request new entries since their last query. -// It is opaque for clients: they should not try to interpret it. -// The server who issued the tag can interpret it as Lamport timestamp. -type Tag string - -// Timestamp decodes the Tag into a Timestamp, which is a monotonically increasing integer value (Lamport timestamp). -// Tags should only be decoded by the server who issued it, so the server should provide the stored tag prefix. -// The tag prefix is a random value that is generated when the service is created. -// It is not a secret; it only makes sure clients receive the complete presentation list when they switch servers for a specific Discovery Service: -// servers return the complete list when the client passes a timestamp the server can't decode. -func (t Tag) Timestamp(tagPrefix string) *Timestamp { - trimmed := strings.TrimPrefix(string(t), tagPrefix) - if len(trimmed) == len(string(t)) { - // Invalid tag prefix - return nil - } - result, err := strconv.ParseUint(trimmed, 10, 64) - if err != nil { - // Not a number - return nil - } - lamport := Timestamp(result) - return &lamport -} - -// Empty returns true if the Tag is empty. -func (t Tag) Empty() bool { - return len(t) == 0 -} - -// Timestamp is the interpreted Tag. -// It's implemented as lamport timestamp (https://en.wikipedia.org/wiki/Lamport_timestamp); -// it is incremented when a new entry is added to the list. -// Pass 0 to start at the beginning of the list. -type Timestamp uint64 - -// Tag returns the Timestamp as Tag. -func (l Timestamp) Tag(serviceSeed string) Tag { - return Tag(serviceSeed + strconv.FormatUint(uint64(l), 10)) -} - -func (l Timestamp) Increment() Timestamp { - return l + 1 -} - // ErrServiceNotFound is returned when a service (ID) is not found in the discovery service. var ErrServiceNotFound = errors.New("discovery service not found") @@ -87,9 +39,11 @@ var ErrPresentationRegistrationFailed = errors.New("registration of Verifiable P type Server interface { // Register registers a presentation on the given Discovery Service. // If the presentation is not valid, or it does not conform to the Service ServiceDefinition, it returns an error. - Register(serviceID string, presentation vc.VerifiablePresentation) error - // Get retrieves the presentations for the given service, starting at the given timestamp. - Get(serviceID string, startAt *Tag) ([]vc.VerifiablePresentation, *Tag, error) + // If the node is not configured as server for the given serviceID, the call will be forwarded to the configured server. + Register(context context.Context, serviceID string, presentation vc.VerifiablePresentation) error + // Get retrieves the presentations for the given service, starting from the given timestamp. + // If the node is not configured as server for the given serviceID, the call will be forwarded to the configured server. + Get(context context.Context, serviceID string, startAfter int) (map[string]vc.VerifiablePresentation, int, error) } // Client defines the API for Discovery Clients. @@ -113,7 +67,7 @@ type Client interface { Services() []ServiceDefinition // GetServiceActivation returns the activation status of a DID on a Discovery Service. - // The boolean indicates whether the DID is acitvated on the Discovery Service (ActivateServiceForDID() has been called). + // The boolean indicates whether the DID is activated on the Discovery Service (ActivateServiceForDID() has been called). // It also returns the Verifiable Presentation that is registered on the Discovery Service, if any. GetServiceActivation(ctx context.Context, serviceID string, subjectDID did.DID) (bool, *vc.VerifiablePresentation, error) } @@ -129,3 +83,6 @@ type SearchResult struct { } type presentationVerifier func(definition ServiceDefinition, presentation vc.VerifiablePresentation) error + +// XForwardedHostContextKey is the context key for the X-Forwarded-Host header. +type XForwardedHostContextKey struct{} diff --git a/discovery/interface_test.go b/discovery/interface_test.go deleted file mode 100644 index 3ac1a0ac75..0000000000 --- a/discovery/interface_test.go +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Copyright (C) 2023 Nuts community - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - * - */ - -package discovery - -import ( - "github.com/stretchr/testify/assert" - "testing" -) - -func TestTag_Empty(t *testing.T) { - t.Run("empty", func(t *testing.T) { - assert.True(t, Tag("").Empty()) - }) - t.Run("not empty", func(t *testing.T) { - assert.False(t, Tag("not empty").Empty()) - }) -} - -func TestTag_Timestamp(t *testing.T) { - t.Run("invalid tag prefix", func(t *testing.T) { - assert.Nil(t, Tag("invalid tag prefix").Timestamp("tag prefix")) - }) - t.Run("not a number", func(t *testing.T) { - assert.Nil(t, Tag("tag prefix").Timestamp("tag prefixnot a number")) - }) - t.Run("invalid uint64", func(t *testing.T) { - assert.Nil(t, Tag("tag prefix").Timestamp("tag prefix")) - }) - t.Run("valid (small number)", func(t *testing.T) { - assert.Equal(t, Timestamp(1), *Tag("tag prefix1").Timestamp("tag prefix")) - }) - t.Run("valid (large number)", func(t *testing.T) { - assert.Equal(t, Timestamp(1234567890), *Tag("tag prefix1234567890").Timestamp("tag prefix")) - }) -} - -func TestTimestamp_Tag(t *testing.T) { - assert.Equal(t, Tag("tag prefix1"), Timestamp(1).Tag("tag prefix")) -} - -func TestTimestamp_Increment(t *testing.T) { - assert.Equal(t, Timestamp(1), Timestamp(0).Increment()) - assert.Equal(t, Timestamp(2), Timestamp(1).Increment()) - assert.Equal(t, Timestamp(1234567890), Timestamp(1234567889).Increment()) -} diff --git a/discovery/mock.go b/discovery/mock.go index cdd851acb8..80cc1dede0 100644 --- a/discovery/mock.go +++ b/discovery/mock.go @@ -42,33 +42,33 @@ func (m *MockServer) EXPECT() *MockServerMockRecorder { } // Get mocks base method. -func (m *MockServer) Get(serviceID string, startAt *Tag) ([]vc.VerifiablePresentation, *Tag, error) { +func (m *MockServer) Get(context context.Context, serviceID string, startAfter int) (map[string]vc.VerifiablePresentation, int, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Get", serviceID, startAt) - ret0, _ := ret[0].([]vc.VerifiablePresentation) - ret1, _ := ret[1].(*Tag) + ret := m.ctrl.Call(m, "Get", context, serviceID, startAfter) + ret0, _ := ret[0].(map[string]vc.VerifiablePresentation) + ret1, _ := ret[1].(int) ret2, _ := ret[2].(error) return ret0, ret1, ret2 } // Get indicates an expected call of Get. -func (mr *MockServerMockRecorder) Get(serviceID, startAt any) *gomock.Call { +func (mr *MockServerMockRecorder) Get(context, serviceID, startAfter any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockServer)(nil).Get), serviceID, startAt) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockServer)(nil).Get), context, serviceID, startAfter) } // Register mocks base method. -func (m *MockServer) Register(serviceID string, presentation vc.VerifiablePresentation) error { +func (m *MockServer) Register(context context.Context, serviceID string, presentation vc.VerifiablePresentation) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Register", serviceID, presentation) + ret := m.ctrl.Call(m, "Register", context, serviceID, presentation) ret0, _ := ret[0].(error) return ret0 } // Register indicates an expected call of Register. -func (mr *MockServerMockRecorder) Register(serviceID, presentation any) *gomock.Call { +func (mr *MockServerMockRecorder) Register(context, serviceID, presentation any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Register", reflect.TypeOf((*MockServer)(nil).Register), serviceID, presentation) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Register", reflect.TypeOf((*MockServer)(nil).Register), context, serviceID, presentation) } // MockClient is a mock of Client interface. diff --git a/discovery/module.go b/discovery/module.go index 06674477e3..17512ccdf0 100644 --- a/discovery/module.go +++ b/discovery/module.go @@ -33,6 +33,7 @@ import ( "github.com/nuts-foundation/nuts-node/vcr" "github.com/nuts-foundation/nuts-node/vcr/credential" "github.com/nuts-foundation/nuts-node/vdr/management" + "net/url" "os" "path" "strings" @@ -42,10 +43,6 @@ import ( const ModuleName = "Discovery" -// ErrServerModeDisabled is returned when a client invokes a Discovery Server (Register or Get) operation on the node, -// for a Discovery Service which it doesn't serve. -var ErrServerModeDisabled = errors.New("node is not a discovery server for this service") - // ErrInvalidPresentation is returned when a client tries to register a Verifiable Presentation that is invalid. var ErrInvalidPresentation = errors.New("presentation is invalid for registration") @@ -58,6 +55,7 @@ var ( errRetractionReferencesUnknownPresentation = errors.New("retraction presentation refers to a non-existing presentation") errRetractionContainsCredentials = errors.New("retraction presentation must not contain credentials") errInvalidRetractionJTIClaim = errors.New("invalid/missing 'retract_jti' claim for retraction presentation") + errCyclicForwardingDetected = errors.New("cyclic forwarding detected") ) var _ core.Injectable = &Module{} @@ -118,11 +116,11 @@ func (m *Module) Configure(serverConfig core.ServerConfig) error { if len(m.config.Server.IDs) > 0 { // Get the definitions that are enabled for this server serverDefinitions := make(map[string]ServiceDefinition) - for _, definitionID := range m.config.Server.IDs { - if definition, exists := m.allDefinitions[definitionID]; !exists { - return fmt.Errorf("service definition '%s' not found", definitionID) + for _, serviceID := range m.config.Server.IDs { + if service, exists := m.allDefinitions[serviceID]; !exists { + return fmt.Errorf("service definition '%s' not found", serviceID) } else { - serverDefinitions[definitionID] = definition + serverDefinitions[serviceID] = service } } m.serverDefinitions = serverDefinitions @@ -133,7 +131,7 @@ func (m *Module) Configure(serverConfig core.ServerConfig) error { func (m *Module) Start() error { var err error - m.store, err = newSQLStore(m.storageInstance.GetSQLDatabase(), m.allDefinitions, m.serverDefinitions) + m.store, err = newSQLStore(m.storageInstance.GetSQLDatabase(), m.allDefinitions) if err != nil { return err } @@ -165,16 +163,31 @@ func (m *Module) Config() interface{} { // Register is a Discovery Server function that registers a presentation on the given Discovery Service. // See interface.go for more information. -func (m *Module) Register(serviceID string, presentation vc.VerifiablePresentation) error { +func (m *Module) Register(context context.Context, serviceID string, presentation vc.VerifiablePresentation) error { // First, simple sanity checks - definition, isServer := m.serverDefinitions[serviceID] + _, isServer := m.serverDefinitions[serviceID] if !isServer { - return ErrServerModeDisabled + // forward to configured server + service, exists := m.allDefinitions[serviceID] + if !exists { + return ErrServiceNotFound + } + + // check If X-Forwarded-Host header is set, if set it must not be the same as service.Endpoint + if cycleDetected(context, service) { + return errCyclicForwardingDetected + } + + // forward to configured server + log.Logger().Infof("Forwarding Register request to configured server (service=%s)", serviceID) + return m.httpClient.Register(context, service.Endpoint, presentation) } + definition := m.allDefinitions[serviceID] if err := m.verifyRegistration(definition, presentation); err != nil { return err } - return m.store.add(definition.ID, presentation, "") + + return m.store.add(serviceID, presentation, 0) } func (m *Module) verifyRegistration(definition ServiceDefinition, presentation vc.VerifiablePresentation) error { @@ -270,13 +283,53 @@ func (m *Module) validateRetraction(serviceID string, presentation vc.Verifiable return nil } -// Get is a Discovery Server function that retrieves the presentations for the given service, starting at the given tag. +// Get is a Discovery Server function that retrieves the presentations for the given service, starting at timestamp+1. // See interface.go for more information. -func (m *Module) Get(serviceID string, tag *Tag) ([]vc.VerifiablePresentation, *Tag, error) { - if _, exists := m.serverDefinitions[serviceID]; !exists { - return nil, nil, ErrServerModeDisabled +func (m *Module) Get(context context.Context, serviceID string, startAfter int) (map[string]vc.VerifiablePresentation, int, error) { + _, exists := m.serverDefinitions[serviceID] + if !exists { + // forward to configured server + service, exists := m.allDefinitions[serviceID] + if !exists { + return nil, 0, ErrServiceNotFound + } + + // check If X-Forwarded-Host header is set, if set it must not be the same as service.Endpoint + if cycleDetected(context, service) { + return nil, 0, errCyclicForwardingDetected + } + + log.Logger().Infof("Forwarding Get request to configured server (service=%s)", serviceID) + return m.httpClient.Get(context, service.Endpoint, startAfter) } - return m.store.get(serviceID, tag) + return m.store.get(serviceID, startAfter) +} + +func cycleDetected(ctx context.Context, service ServiceDefinition) bool { + host := forwardedHost(ctx) + if host == "" { + return false + } + myUri, err := url.Parse(host) + if err != nil { + return false + } + targetUri, err := url.Parse(service.Endpoint) + if err != nil { + return false + } + + return myUri.Host == targetUri.Host +} + +func forwardedHost(ctx context.Context) string { + // get value from context using "X-Forwarded-Host" key + forwardedHostValue := ctx.Value(XForwardedHostContextKey{}) + host, ok := forwardedHostValue.(string) + if !ok { + return "" + } + return host } // ActivateServiceForDID is a Discovery Client function that activates a service for a DID. diff --git a/discovery/module_test.go b/discovery/module_test.go index 98346d0a77..7dfba67a46 100644 --- a/discovery/module_test.go +++ b/discovery/module_test.go @@ -43,89 +43,98 @@ func TestModule_Name(t *testing.T) { } func TestModule_Shutdown(t *testing.T) { - module, _, _ := setupModule(t, storage.NewTestStorageEngine(t)) - require.NoError(t, module.Shutdown()) + m, _ := setupModule(t, storage.NewTestStorageEngine(t)) + require.NoError(t, m.Shutdown()) } func Test_Module_Register(t *testing.T) { storageEngine := storage.NewTestStorageEngine(t) require.NoError(t, storageEngine.Start()) + ctx := context.Background() - t.Run("not a server", func(t *testing.T) { - m, _, _ := setupModule(t, storageEngine) + t.Run("registration", func(t *testing.T) { + t.Run("ok", func(t *testing.T) { + m, testContext := setupModule(t, storageEngine) + testContext.verifier.EXPECT().VerifyVP(gomock.Any(), true, true, nil) - err := m.Register("other", vpAlice) - require.EqualError(t, err, "node is not a discovery server for this service") - }) - t.Run("VP verification fails (e.g. invalid signature)", func(t *testing.T) { - m, presentationVerifier, _ := setupModule(t, storageEngine) - presentationVerifier.EXPECT().VerifyVP(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("failed")) + err := m.Register(ctx, testServiceID, vpAlice) + require.NoError(t, err) - err := m.Register(testServiceID, vpAlice) - require.EqualError(t, err, "presentation is invalid for registration\npresentation verification failed: failed") + _, timestamp, err := m.Get(ctx, testServiceID, 0) + require.NoError(t, err) + assert.Equal(t, 1, timestamp) + }) + t.Run("not a server", func(t *testing.T) { + m, _ := setupModule(t, storageEngine, func(module *Module) { + module.allDefinitions["someother"] = ServiceDefinition{ + ID: "someother", + Endpoint: "https://example.com/someother", + } + mockhttpclient := module.httpClient.(*client.MockHTTPClient) + mockhttpclient.EXPECT().Get(gomock.Any(), "https://example.com/someother", gomock.Any()).Return(nil, 0, nil).AnyTimes() + mockhttpclient.EXPECT().Register(gomock.Any(), "https://example.com/someother", vpAlice).Return(nil) + }) - _, tag, err := m.Get(testServiceID, nil) - require.NoError(t, err) - expectedTag := tagForTimestamp(t, m.store, testServiceID, 0) - assert.Equal(t, expectedTag, *tag) - }) - t.Run("already exists", func(t *testing.T) { - m, presentationVerifier, _ := setupModule(t, storageEngine) - presentationVerifier.EXPECT().VerifyVP(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + err := m.Register(ctx, "someother", vpAlice) - err := m.Register(testServiceID, vpAlice) - assert.NoError(t, err) - err = m.Register(testServiceID, vpAlice) - assert.ErrorIs(t, err, ErrPresentationAlreadyExists) - }) - t.Run("valid for too long", func(t *testing.T) { - m, _, _ := setupModule(t, storageEngine, func(module *Module) { - def := module.allDefinitions[testServiceID] - def.PresentationMaxValidity = 1 - module.allDefinitions[testServiceID] = def - module.serverDefinitions[testServiceID] = def + assert.NoError(t, err) }) - err := m.Register(testServiceID, vpAlice) - assert.EqualError(t, err, "presentation is invalid for registration\npresentation is valid for too long (max 1s)") - }) - t.Run("no expiration", func(t *testing.T) { - m, _, _ := setupModule(t, storageEngine) - err := m.Register(testServiceID, createPresentationCustom(aliceDID, func(claims map[string]interface{}, _ *vc.VerifiablePresentation) { - claims[jwt.AudienceKey] = []string{testServiceID} - delete(claims, "exp") - })) - assert.ErrorIs(t, err, errPresentationWithoutExpiration) - }) - t.Run("presentation does not contain an ID", func(t *testing.T) { - m, _, _ := setupModule(t, storageEngine) + t.Run("VP verification fails (e.g. invalid signature)", func(t *testing.T) { + m, testContext := setupModule(t, storageEngine) + testContext.verifier.EXPECT().VerifyVP(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("failed")) - vpWithoutID := createPresentationCustom(aliceDID, func(claims map[string]interface{}, _ *vc.VerifiablePresentation) { - claims[jwt.AudienceKey] = []string{testServiceID} - delete(claims, "jti") - }, vcAlice) - err := m.Register(testServiceID, vpWithoutID) - assert.ErrorIs(t, err, errPresentationWithoutID) - }) - t.Run("not a JWT", func(t *testing.T) { - m, _, _ := setupModule(t, storageEngine) - err := m.Register(testServiceID, vc.VerifiablePresentation{}) - assert.ErrorIs(t, err, errUnsupportedPresentationFormat) - }) - - t.Run("registration", func(t *testing.T) { - t.Run("ok", func(t *testing.T) { - m, presentationVerifier, _ := setupModule(t, storageEngine) - presentationVerifier.EXPECT().VerifyVP(gomock.Any(), true, true, nil) + err := m.Register(ctx, testServiceID, vpAlice) + require.EqualError(t, err, "presentation is invalid for registration\npresentation verification failed: failed") - err := m.Register(testServiceID, vpAlice) + _, timestamp, err := m.Get(ctx, testServiceID, 0) require.NoError(t, err) + assert.Equal(t, 0, timestamp) + }) + t.Run("already exists", func(t *testing.T) { + m, testContext := setupModule(t, storageEngine) + testContext.verifier.EXPECT().VerifyVP(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - _, tag, err := m.Get(testServiceID, nil) - require.NoError(t, err) - assert.Equal(t, "1", string(*tag)[tagPrefixLength:]) + err := m.Register(ctx, testServiceID, vpAlice) + assert.NoError(t, err) + err = m.Register(ctx, testServiceID, vpAlice) + assert.ErrorIs(t, err, ErrPresentationAlreadyExists) + }) + t.Run("valid for too long", func(t *testing.T) { + m, _ := setupModule(t, storageEngine, func(module *Module) { + def := module.allDefinitions[testServiceID] + def.PresentationMaxValidity = 1 + module.allDefinitions[testServiceID] = def + }) + + err := m.Register(ctx, testServiceID, vpAlice) + + assert.EqualError(t, err, "presentation is invalid for registration\npresentation is valid for too long (max 1s)") + }) + t.Run("no expiration", func(t *testing.T) { + m, _ := setupModule(t, storageEngine) + err := m.Register(ctx, testServiceID, createPresentationCustom(aliceDID, func(claims map[string]interface{}, _ *vc.VerifiablePresentation) { + claims[jwt.AudienceKey] = []string{testServiceID} + delete(claims, "exp") + })) + assert.ErrorIs(t, err, errPresentationWithoutExpiration) + }) + t.Run("presentation does not contain an ID", func(t *testing.T) { + m, _ := setupModule(t, storageEngine) + + vpWithoutID := createPresentationCustom(aliceDID, func(claims map[string]interface{}, _ *vc.VerifiablePresentation) { + claims[jwt.AudienceKey] = []string{testServiceID} + delete(claims, "jti") + }, vcAlice) + err := m.Register(ctx, testServiceID, vpWithoutID) + assert.ErrorIs(t, err, errPresentationWithoutID) + }) + t.Run("not a JWT", func(t *testing.T) { + m, _ := setupModule(t, storageEngine) + err := m.Register(ctx, testServiceID, vc.VerifiablePresentation{}) + assert.ErrorIs(t, err, errUnsupportedPresentationFormat) }) t.Run("valid longer than its credentials", func(t *testing.T) { - m, _, _ := setupModule(t, storageEngine) + m, _ := setupModule(t, storageEngine) vcAlice := createCredential(authorityDID, aliceDID, nil, func(claims map[string]interface{}) { claims[jwt.AudienceKey] = []string{testServiceID} @@ -134,21 +143,36 @@ func Test_Module_Register(t *testing.T) { vpAlice := createPresentationCustom(aliceDID, func(claims map[string]interface{}, vp *vc.VerifiablePresentation) { claims[jwt.AudienceKey] = []string{testServiceID} }, vcAlice) - err := m.Register(testServiceID, vpAlice) + err := m.Register(ctx, testServiceID, vpAlice) assert.ErrorIs(t, err, errPresentationValidityExceedsCredentials) }) t.Run("not conform to Presentation Definition", func(t *testing.T) { - m, _, _ := setupModule(t, storageEngine) + m, _ := setupModule(t, storageEngine) // Presentation Definition only allows did:example DIDs otherVP := createPresentationCustom(unsupportedDID, func(claims map[string]interface{}, vp *vc.VerifiablePresentation) { claims[jwt.AudienceKey] = []string{testServiceID} }, createCredential(unsupportedDID, unsupportedDID, nil, nil)) - err := m.Register(testServiceID, otherVP) + err := m.Register(ctx, testServiceID, otherVP) require.ErrorContains(t, err, "presentation does not fulfill Presentation ServiceDefinition") - _, tag, _ := m.Get(testServiceID, nil) - assert.Equal(t, "0", string(*tag)[tagPrefixLength:]) + _, timestamp, _ := m.Get(ctx, testServiceID, 0) + assert.Equal(t, 0, timestamp) + }) + t.Run("cycle detected", func(t *testing.T) { + m, _ := setupModule(t, storageEngine, func(module *Module) { + module.allDefinitions["someother"] = ServiceDefinition{ + ID: "someother", + Endpoint: "https://example.com/someother", + } + mockhttpclient := module.httpClient.(*client.MockHTTPClient) + mockhttpclient.EXPECT().Get(gomock.Any(), "https://example.com/someother", gomock.Any()).Return(nil, 0, nil).AnyTimes() + }) + ctx := context.WithValue(ctx, XForwardedHostContextKey{}, "https://example.com") + + err := m.Register(ctx, "someother", vc.VerifiablePresentation{}) + + assert.ErrorIs(t, err, errCyclicForwardingDetected) }) }) t.Run("retraction", func(t *testing.T) { @@ -158,55 +182,55 @@ func Test_Module_Register(t *testing.T) { claims[jwt.AudienceKey] = []string{testServiceID} }) t.Run("ok", func(t *testing.T) { - m, presentationVerifier, _ := setupModule(t, storageEngine) - presentationVerifier.EXPECT().VerifyVP(gomock.Any(), true, true, nil).Times(2) + m, testContext := setupModule(t, storageEngine) + testContext.verifier.EXPECT().VerifyVP(gomock.Any(), true, true, nil).Times(2) - err := m.Register(testServiceID, vpAlice) + err := m.Register(ctx, testServiceID, vpAlice) require.NoError(t, err) - err = m.Register(testServiceID, vpAliceRetract) + err = m.Register(ctx, testServiceID, vpAliceRetract) assert.NoError(t, err) }) t.Run("non-existent presentation", func(t *testing.T) { - m, _, _ := setupModule(t, storageEngine) - err := m.Register(testServiceID, vpAliceRetract) + m, _ := setupModule(t, storageEngine) + err := m.Register(ctx, testServiceID, vpAliceRetract) assert.ErrorIs(t, err, errRetractionReferencesUnknownPresentation) }) t.Run("must not contain credentials", func(t *testing.T) { - m, _, _ := setupModule(t, storageEngine) + m, _ := setupModule(t, storageEngine) vp := createPresentationCustom(aliceDID, func(claims map[string]interface{}, vp *vc.VerifiablePresentation) { vp.Type = append(vp.Type, retractionPresentationType) claims[jwt.AudienceKey] = []string{testServiceID} }, vcAlice) - err := m.Register(testServiceID, vp) + err := m.Register(ctx, testServiceID, vp) assert.ErrorIs(t, err, errRetractionContainsCredentials) }) t.Run("missing 'retract_jti' claim", func(t *testing.T) { - m, _, _ := setupModule(t, storageEngine) + m, _ := setupModule(t, storageEngine) vp := createPresentationCustom(aliceDID, func(claims map[string]interface{}, vp *vc.VerifiablePresentation) { vp.Type = append(vp.Type, retractionPresentationType) claims[jwt.AudienceKey] = []string{testServiceID} }) - err := m.Register(testServiceID, vp) + err := m.Register(ctx, testServiceID, vp) assert.ErrorIs(t, err, errInvalidRetractionJTIClaim) }) t.Run("'retract_jti' claim is not a string", func(t *testing.T) { - m, _, _ := setupModule(t, storageEngine) + m, _ := setupModule(t, storageEngine) vp := createPresentationCustom(aliceDID, func(claims map[string]interface{}, vp *vc.VerifiablePresentation) { vp.Type = append(vp.Type, retractionPresentationType) claims["retract_jti"] = 10 claims[jwt.AudienceKey] = []string{testServiceID} }) - err := m.Register(testServiceID, vp) + err := m.Register(ctx, testServiceID, vp) assert.ErrorIs(t, err, errInvalidRetractionJTIClaim) }) t.Run("'retract_jti' claim is an empty string", func(t *testing.T) { - m, _, _ := setupModule(t, storageEngine) + m, _ := setupModule(t, storageEngine) vp := createPresentationCustom(aliceDID, func(claims map[string]interface{}, vp *vc.VerifiablePresentation) { vp.Type = append(vp.Type, retractionPresentationType) claims["retract_jti"] = "" claims[jwt.AudienceKey] = []string{testServiceID} }) - err := m.Register(testServiceID, vp) + err := m.Register(ctx, testServiceID, vp) assert.ErrorIs(t, err, errInvalidRetractionJTIClaim) }) }) @@ -215,29 +239,62 @@ func Test_Module_Register(t *testing.T) { func Test_Module_Get(t *testing.T) { storageEngine := storage.NewTestStorageEngine(t) require.NoError(t, storageEngine.Start()) + ctx := context.Background() t.Run("ok", func(t *testing.T) { - m, _, _ := setupModule(t, storageEngine) - require.NoError(t, m.store.add(testServiceID, vpAlice, "")) - presentations, tag, err := m.Get(testServiceID, nil) + m, _ := setupModule(t, storageEngine) + require.NoError(t, m.store.add(testServiceID, vpAlice, 0)) + presentations, timestamp, err := m.Get(ctx, testServiceID, 0) assert.NoError(t, err) - assert.Equal(t, []vc.VerifiablePresentation{vpAlice}, presentations) - assert.Equal(t, "1", string(*tag)[tagPrefixLength:]) + assert.Equal(t, map[string]vc.VerifiablePresentation{"1": vpAlice}, presentations) + assert.Equal(t, 1, timestamp) }) t.Run("ok - retrieve delta", func(t *testing.T) { - m, _, _ := setupModule(t, storageEngine) - require.NoError(t, m.store.add(testServiceID, vpAlice, "")) - presentations, _, err := m.Get(testServiceID, nil) + m, _ := setupModule(t, storageEngine) + require.NoError(t, m.store.add(testServiceID, vpAlice, 0)) + presentations, _, err := m.Get(ctx, testServiceID, 0) require.NoError(t, err) require.Len(t, presentations, 1) }) - t.Run("not a server for this service ID", func(t *testing.T) { - m, _, _ := setupModule(t, storageEngine) - _, _, err := m.Get("other", nil) - assert.ErrorIs(t, err, ErrServerModeDisabled) + t.Run("not a server for this service ID, call forwarded", func(t *testing.T) { + m, _ := setupModule(t, storageEngine, func(module *Module) { + module.allDefinitions["someother"] = ServiceDefinition{ + ID: "someother", + Endpoint: "https://example.com/someother", + } + mockhttpclient := module.httpClient.(*client.MockHTTPClient) + mockhttpclient.EXPECT().Get(gomock.Any(), "https://example.com/someother", 0).Return(map[string]vc.VerifiablePresentation{"1": vpAlice}, 1, nil).AnyTimes() + }) + + presentations, timestamp, err := m.Get(ctx, "someother", 0) + + require.NoError(t, err) + assert.Equal(t, 1, timestamp) + assert.Len(t, presentations, 1) + }) + t.Run("not a server for this service ID, call forwarded, cycle detected", func(t *testing.T) { + m, _ := setupModule(t, storageEngine, func(module *Module) { + module.allDefinitions["someother"] = ServiceDefinition{ + ID: "someother", + Endpoint: "https://example.com/someother", + } + mockhttpclient := module.httpClient.(*client.MockHTTPClient) + mockhttpclient.EXPECT().Get(gomock.Any(), "https://example.com/someother", 0).Return(nil, 0, nil).AnyTimes() + }) + ctx := context.WithValue(ctx, XForwardedHostContextKey{}, "https://example.com") + + _, _, err := m.Get(ctx, "someother", 0) + + assert.ErrorIs(t, err, errCyclicForwardingDetected) }) } -func setupModule(t *testing.T, storageInstance storage.Engine, visitors ...func(*Module)) (*Module, *verifier.MockVerifier, *management.MockDocumentOwner) { +type mockContext struct { + ctrl *gomock.Controller + documentOwner *management.MockDocumentOwner + verifier *verifier.MockVerifier +} + +func setupModule(t *testing.T, storageInstance storage.Engine, visitors ...func(module *Module)) (*Module, mockContext) { resetStore(t, storageInstance.GetSQLDatabase()) ctrl := gomock.NewController(t) mockVerifier := verifier.NewMockVerifier(ctrl) @@ -248,20 +305,27 @@ func setupModule(t *testing.T, storageInstance storage.Engine, visitors ...func( m.config = DefaultConfig() require.NoError(t, m.Configure(core.TestServerConfig())) httpClient := client.NewMockHTTPClient(ctrl) - httpClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, "", nil).AnyTimes() + httpClient.EXPECT().Get(gomock.Any(), "http://example.com/other", gomock.Any()).Return(nil, 0, nil).AnyTimes() + httpClient.EXPECT().Get(gomock.Any(), "http://example.com/usecase", gomock.Any()).Return(nil, 0, nil).AnyTimes() m.httpClient = httpClient m.allDefinitions = testDefinitions() m.serverDefinitions = map[string]ServiceDefinition{ testServiceID: m.allDefinitions[testServiceID], } + for _, visitor := range visitors { visitor(m) } + require.NoError(t, m.Start()) t.Cleanup(func() { _ = m.Shutdown() }) - return m, mockVerifier, documentOwner + return m, mockContext{ + ctrl: ctrl, + documentOwner: documentOwner, + verifier: mockVerifier, + } } func TestModule_Configure(t *testing.T) { @@ -317,8 +381,10 @@ func TestModule_Search(t *testing.T) { storageEngine := storage.NewTestStorageEngine(t) require.NoError(t, storageEngine.Start()) t.Run("ok", func(t *testing.T) { - m, _, _ := setupModule(t, storageEngine) - require.NoError(t, m.store.add(testServiceID, vpAlice, "")) + m, _ := setupModule(t, storageEngine) + + require.NoError(t, m.store.add(testServiceID, vpAlice, 0)) + results, err := m.Search(testServiceID, map[string]string{ "credentialSubject.id": aliceDID.String(), }) @@ -333,7 +399,7 @@ func TestModule_Search(t *testing.T) { assert.JSONEq(t, string(expectedJSON), string(actualJSON)) }) t.Run("unknown service ID", func(t *testing.T) { - m, _, _ := setupModule(t, storageEngine) + m, _ := setupModule(t, storageEngine) _, err := m.Search("unknown", nil) assert.ErrorIs(t, err, ErrServiceNotFound) }) @@ -343,25 +409,25 @@ func TestModule_update(t *testing.T) { storageEngine := storage.NewTestStorageEngine(t) require.NoError(t, storageEngine.Start()) t.Run("Start() initiates update", func(t *testing.T) { - _, _, _ = setupModule(t, storageEngine, func(module *Module) { + _, _ = setupModule(t, storageEngine, func(module *Module) { // we want to assert the job runs, so make it run very often to make the test faster module.config.Client.RefreshInterval = 1 * time.Millisecond // overwrite httpClient mock for custom behavior assertions (we want to know how often HttpClient.Get() was called) httpClient := client.NewMockHTTPClient(gomock.NewController(t)) // Get() should be called at least twice (times the number of Service Definitions), once for the initial run on startup, then again after the refresh interval - httpClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, "", nil).MinTimes(2 * len(module.allDefinitions)) + httpClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, 0, nil).MinTimes(2 * len(module.allDefinitions)) module.httpClient = httpClient }) time.Sleep(10 * time.Millisecond) }) t.Run("update() runs on node startup", func(t *testing.T) { - _, _, _ = setupModule(t, storageEngine, func(module *Module) { + _, _ = setupModule(t, storageEngine, func(module *Module) { // we want to assert the job immediately executes on node startup, even if the refresh interval hasn't passed module.config.Client.RefreshInterval = time.Hour // overwrite httpClient mock for custom behavior assertions (we want to know how often HttpClient.Get() was called) httpClient := client.NewMockHTTPClient(gomock.NewController(t)) // update causes call to HttpClient.Get(), once for each Service Definition - httpClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, "", nil).Times(len(module.allDefinitions)) + httpClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, 0, nil).Times(len(module.allDefinitions)) module.httpClient = httpClient }) }) @@ -371,11 +437,11 @@ func TestModule_ActivateServiceForDID(t *testing.T) { t.Run("ok, syncs VPs immediately after registration", func(t *testing.T) { storageEngine := storage.NewTestStorageEngine(t) require.NoError(t, storageEngine.Start()) - m, _, documentOwner := setupModule(t, storageEngine, func(module *Module) { + m, testContext := setupModule(t, storageEngine, func(module *Module) { // overwrite httpClient mock for custom behavior assertions (we want to know how often HttpClient.Get() was called) httpClient := client.NewMockHTTPClient(gomock.NewController(t)) httpClient.EXPECT().Register(gomock.Any(), gomock.Any(), vpAlice).Return(nil) - httpClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, "", nil) + httpClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, 0, nil) module.httpClient = httpClient // disable auto-refresh job to have deterministic assertions module.config.Client.RefreshInterval = 0 @@ -385,7 +451,7 @@ func TestModule_ActivateServiceForDID(t *testing.T) { m.vcrInstance.(*vcr.MockVCR).EXPECT().Wallet().Return(wallet).MinTimes(1) wallet.EXPECT().List(gomock.Any(), gomock.Any()).Return([]vc.VerifiableCredential{vcAlice}, nil) wallet.EXPECT().BuildPresentation(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&vpAlice, nil) - documentOwner.EXPECT().IsOwner(gomock.Any(), aliceDID).Return(true, nil) + testContext.documentOwner.EXPECT().IsOwner(gomock.Any(), aliceDID).Return(true, nil) err := m.ActivateServiceForDID(context.Background(), testServiceID, aliceDID) @@ -394,8 +460,8 @@ func TestModule_ActivateServiceForDID(t *testing.T) { t.Run("not owned", func(t *testing.T) { storageEngine := storage.NewTestStorageEngine(t) require.NoError(t, storageEngine.Start()) - m, _, documentOwner := setupModule(t, storageEngine) - documentOwner.EXPECT().IsOwner(gomock.Any(), aliceDID).Return(false, nil) + m, testContext := setupModule(t, storageEngine) + testContext.documentOwner.EXPECT().IsOwner(gomock.Any(), aliceDID).Return(false, nil) err := m.ActivateServiceForDID(context.Background(), testServiceID, aliceDID) @@ -404,11 +470,11 @@ func TestModule_ActivateServiceForDID(t *testing.T) { t.Run("ok, but couldn't register presentation -> maps to ErrRegistrationFailed", func(t *testing.T) { storageEngine := storage.NewTestStorageEngine(t) require.NoError(t, storageEngine.Start()) - m, _, documentOwner := setupModule(t, storageEngine) + m, testContext := setupModule(t, storageEngine) wallet := holder.NewMockWallet(gomock.NewController(t)) m.vcrInstance.(*vcr.MockVCR).EXPECT().Wallet().Return(wallet).MinTimes(1) wallet.EXPECT().List(gomock.Any(), gomock.Any()).Return(nil, errors.New("failed")).MinTimes(1) - documentOwner.EXPECT().IsOwner(gomock.Any(), aliceDID).Return(true, nil) + testContext.documentOwner.EXPECT().IsOwner(gomock.Any(), aliceDID).Return(true, nil) err := m.ActivateServiceForDID(context.Background(), testServiceID, aliceDID) @@ -431,7 +497,7 @@ func TestModule_GetServiceActivation(t *testing.T) { storageEngine := storage.NewTestStorageEngine(t) require.NoError(t, storageEngine.Start()) t.Run("not activated", func(t *testing.T) { - m, _, _ := setupModule(t, storageEngine) + m, _ := setupModule(t, storageEngine) activated, presentation, err := m.GetServiceActivation(context.Background(), testServiceID, aliceDID) @@ -440,7 +506,7 @@ func TestModule_GetServiceActivation(t *testing.T) { assert.Nil(t, presentation) }) t.Run("activated, no VP", func(t *testing.T) { - m, _, _ := setupModule(t, storageEngine) + m, _ := setupModule(t, storageEngine) next := time.Now() _ = m.store.updatePresentationRefreshTime(testServiceID, aliceDID, &next) @@ -451,10 +517,10 @@ func TestModule_GetServiceActivation(t *testing.T) { assert.Nil(t, presentation) }) t.Run("activated, with VP", func(t *testing.T) { - m, _, _ := setupModule(t, storageEngine) + m, _ := setupModule(t, storageEngine) next := time.Now() _ = m.store.updatePresentationRefreshTime(testServiceID, aliceDID, &next) - _ = m.store.add(testServiceID, vpAlice, "") + _ = m.store.add(testServiceID, vpAlice, 0) activated, presentation, err := m.GetServiceActivation(context.Background(), testServiceID, aliceDID) diff --git a/discovery/store.go b/discovery/store.go index 8ffe5f00e2..428fe55b67 100644 --- a/discovery/store.go +++ b/discovery/store.go @@ -23,7 +23,6 @@ import ( "fmt" "github.com/nuts-foundation/go-did/did" "github.com/nuts-foundation/nuts-node/vcr/credential/store" - "math/rand" "time" "github.com/google/uuid" @@ -35,12 +34,9 @@ import ( "gorm.io/gorm/schema" ) -const tagPrefixLength = 5 - type serviceRecord struct { - ID string `gorm:"primaryKey"` - LastTag Tag - TagPrefix string + ID string `gorm:"primaryKey"` + LastLamportTimestamp int } func (s serviceRecord) TableName() string { @@ -52,7 +48,7 @@ var _ schema.Tabler = (*presentationRecord)(nil) type presentationRecord struct { ID string `gorm:"primaryKey"` ServiceID string - LamportTimestamp uint64 + LamportTimestamp int CredentialSubjectID string PresentationID string PresentationRaw string @@ -86,7 +82,7 @@ type presentationRefreshRecord struct { ServiceID string `gorm:"primaryKey"` // Did is Did that should be registered on the service. Did string `gorm:"primaryKey"` - // NextRefresh is the timestamp (seconds since Unix epoch) when the registration on the Discovery Service should be refreshed. + // NextRefresh is the Timestamp (seconds since Unix epoch) when the registration on the Discovery Service should be refreshed. NextRefresh int64 } @@ -99,16 +95,12 @@ type sqlStore struct { db *gorm.DB } -func newSQLStore(db *gorm.DB, clientDefinitions map[string]ServiceDefinition, serverDefinitions map[string]ServiceDefinition) (*sqlStore, error) { +func newSQLStore(db *gorm.DB, clientDefinitions map[string]ServiceDefinition) (*sqlStore, error) { // Creates entries in the discovery service table, if they don't exist yet for _, definition := range clientDefinitions { currentList := serviceRecord{ ID: definition.ID, } - // If the node is server for this discovery service, make sure the timestamp prefix is set. - if _, isServer := serverDefinitions[definition.ID]; isServer { - currentList.TagPrefix = generatePrefix() - } if err := db.FirstOrCreate(¤tList, "id = ?", definition.ID).Error; err != nil { return nil, err } @@ -116,11 +108,9 @@ func newSQLStore(db *gorm.DB, clientDefinitions map[string]ServiceDefinition, se return &sqlStore{db: db}, nil } -// Add adds a presentation to the list of presentations. -// A non-empty Tag should be passed if the presentation was received from a remote Discovery Server, then it is stored alongside the presentation. -// If the local node is the Discovery Server and thus is responsible for the timestamping, -// an empty Tag should be passed to let the store determine the right value. -func (s *sqlStore) add(serviceID string, presentation vc.VerifiablePresentation, tag Tag) error { +// add adds a presentation to the list of presentations. +// If the given timestamp is 0, the server will assign a timestamp. +func (s *sqlStore) add(serviceID string, presentation vc.VerifiablePresentation, timestamp int) error { credentialSubjectID, err := credential.PresentationSigner(presentation) if err != nil { return err @@ -129,7 +119,13 @@ func (s *sqlStore) add(serviceID string, presentation vc.VerifiablePresentation, return err } return s.db.Transaction(func(tx *gorm.DB) error { - newTimestamp, err := s.updateTag(tx, serviceID, tag) + if timestamp == 0 { + var newTs *int + newTs, err = s.incrementTimestamp(tx, serviceID) + timestamp = *newTs + } else { + err = s.setTimestamp(tx, serviceID, timestamp) + } if err != nil { return err } @@ -139,12 +135,12 @@ func (s *sqlStore) add(serviceID string, presentation vc.VerifiablePresentation, return err } - return storePresentation(tx, serviceID, newTimestamp, presentation) + return storePresentation(tx, serviceID, timestamp, presentation) }) } // storePresentation creates a presentationRecord from a VerifiablePresentation and stores it, with its credentials, in the database. -func storePresentation(tx *gorm.DB, serviceID string, timestamp *Timestamp, presentation vc.VerifiablePresentation) error { +func storePresentation(tx *gorm.DB, serviceID string, timestamp int, presentation vc.VerifiablePresentation) error { credentialSubjectID, err := credential.PresentationSigner(presentation) if err != nil { return err @@ -154,13 +150,11 @@ func storePresentation(tx *gorm.DB, serviceID string, timestamp *Timestamp, pres ID: uuid.NewString(), ServiceID: serviceID, CredentialSubjectID: credentialSubjectID.String(), + LamportTimestamp: timestamp, PresentationID: presentation.ID.String(), PresentationRaw: presentation.Raw(), PresentationExpiration: presentation.JWT().Expiration().Unix(), } - if timestamp != nil { - newPresentation.LamportTimestamp = uint64(*timestamp) - } credentialStore := store.CredentialStore{} for _, verifiableCredential := range presentation.VerifiableCredential { @@ -178,42 +172,28 @@ func storePresentation(tx *gorm.DB, serviceID string, timestamp *Timestamp, pres return tx.Create(&newPresentation).Error } -// get returns all presentations, registered on the given service, starting after the given tag. -// It also returns the latest tag of the returned presentations. -// This tag can then be used next time to only retrieve presentations that were added after that tag. -func (s *sqlStore) get(serviceID string, tag *Tag) ([]vc.VerifiablePresentation, *Tag, error) { +// get returns all presentations, registered on the given service, starting after the given timestamp. +// It also returns the latest timestamp of the returned presentations. +func (s *sqlStore) get(serviceID string, startAfter int) (map[string]vc.VerifiablePresentation, int, error) { var service serviceRecord if err := s.db.Find(&service, "id = ?", serviceID).Error; err != nil { - return nil, nil, fmt.Errorf("query service '%s': %w", serviceID, err) - } - var startAfter uint64 - if tag != nil { - // Decode tag - lamportTimestamp := tag.Timestamp(service.TagPrefix) - if lamportTimestamp != nil { - startAfter = uint64(*lamportTimestamp) - } + return nil, 0, fmt.Errorf("query service '%s': %w", serviceID, err) } var rows []presentationRecord err := s.db.Order("lamport_timestamp ASC").Find(&rows, "service_id = ? AND lamport_timestamp > ?", serviceID, startAfter).Error if err != nil { - return nil, nil, fmt.Errorf("query service '%s': %w", serviceID, err) + return nil, 0, fmt.Errorf("query service '%s': %w", serviceID, err) } - presentations := make([]vc.VerifiablePresentation, 0, len(rows)) + presentations := make(map[string]vc.VerifiablePresentation, len(rows)) for _, row := range rows { presentation, err := vc.ParseVerifiablePresentation(row.PresentationRaw) if err != nil { - return nil, nil, fmt.Errorf("parse presentation '%s' of service '%s': %w", row.PresentationID, serviceID, err) + return nil, 0, fmt.Errorf("parse presentation '%s' of service '%s': %w", row.PresentationID, serviceID, err) } - presentations = append(presentations, *presentation) + presentations[fmt.Sprintf("%d", row.LamportTimestamp)] = *presentation } - lastTag := service.LastTag - if lastTag.Empty() { - // Make sure we don't return an empty string for the tag, instead return tag indicating the beginning of the list. - lastTag = Timestamp(0).Tag(service.TagPrefix) - } - return presentations, &lastTag, nil + return presentations, service.LastLamportTimestamp, nil } // search searches for presentations, registered on the given service, matching the given query. @@ -244,12 +224,10 @@ func (s *sqlStore) search(serviceID string, query map[string]string) ([]vc.Verif return results, nil } -// updateTag updates the tag of the given service. -// Clients should pass the tag they received from the server (which simply sets it). -// Servers should pass an empty tag (since they "own" the tag), which causes it to be incremented. -func (s *sqlStore) updateTag(tx *gorm.DB, serviceID string, newTimestamp Tag) (*Timestamp, error) { +// incrementTimestamp increments the last_timestamp of the given service. +func (s *sqlStore) incrementTimestamp(tx *gorm.DB, serviceID string) (*int, error) { var service serviceRecord - // Lock (SELECT FOR UPDATE) discovery_service row to prevent concurrent updates to the same list, which could mess up the lamport timestamp. + // Lock (SELECT FOR UPDATE) discovery_service row to prevent concurrent updates to the same list, which could mess up the last Timestamp. if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}). Where(serviceRecord{ID: serviceID}). Find(&service). @@ -257,30 +235,27 @@ func (s *sqlStore) updateTag(tx *gorm.DB, serviceID string, newTimestamp Tag) (* return nil, err } service.ID = serviceID - var result *Timestamp - if newTimestamp.Empty() { - // Update tag: decode current timestamp, increment it, encode it again. - currTimestamp := Timestamp(0) - if service.LastTag != "" { - // If LastTag is empty, it means the service was just created and no presentations were added yet. - ts := service.LastTag.Timestamp(service.TagPrefix) - if ts == nil { - // would be very weird: can't decode it, although it's our own tag - return nil, fmt.Errorf("can't decode tag '%s', did someone alter 'service.tag_prefix' or 'service.last_tag' in the database?", service.LastTag) - } - currTimestamp = *ts - } - ts := currTimestamp.Increment() - result = &ts - service.LastTag = ts.Tag(service.TagPrefix) - } else { - // Set tag: just store it - service.LastTag = newTimestamp - } + service.LastLamportTimestamp = service.LastLamportTimestamp + 1 + if err := tx.Save(service).Error; err != nil { return nil, err } - return result, nil + return &service.LastLamportTimestamp, nil +} + +// setTimestamp sets the last_timestamp of the given service. +func (s *sqlStore) setTimestamp(tx *gorm.DB, serviceID string, timestamp int) error { + var service serviceRecord + // Lock (SELECT FOR UPDATE) discovery_service row to prevent concurrent updates to the same list, which could mess up the last Timestamp. + if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}). + Where(serviceRecord{ID: serviceID}). + Find(&service). + Error; err != nil { + return err + } + service.ID = serviceID + service.LastLamportTimestamp = timestamp + return tx.Save(service).Error } // exists checks whether a presentation of the given subject is registered on a service. @@ -361,27 +336,13 @@ func (s *sqlStore) getPresentationsToBeRefreshed(now time.Time) ([]string, []did return serviceIDs, dids, nil } -func (s *sqlStore) getTag(serviceID string) (Tag, error) { +func (s *sqlStore) getTimestamp(serviceID string) (int, error) { var service serviceRecord err := s.db.Find(&service, "id = ?", serviceID).Error if errors.Is(err, gorm.ErrRecordNotFound) { - return "", nil + return 0, nil } else if err != nil { - return "", fmt.Errorf("query service '%s': %w", serviceID, err) - } - if service.LastTag.Empty() { - return "", nil - } - return service.LastTag, nil -} - -// generatePrefix generates a random seed for a service, consisting of 5 uppercase letters. -func generatePrefix() string { - result := make([]byte, tagPrefixLength) - lower := int('A') - upper := int('Z') - for i := 0; i < len(result); i++ { - result[i] = byte(lower + rand.Intn(upper-lower)) + return 0, fmt.Errorf("query service '%s': %w", serviceID, err) } - return string(result) + return service.LastLamportTimestamp, nil } diff --git a/discovery/store_test.go b/discovery/store_test.go index 907a2649a4..fbe1afa0ab 100644 --- a/discovery/store_test.go +++ b/discovery/store_test.go @@ -43,21 +43,21 @@ func Test_sqlStore_exists(t *testing.T) { }) t.Run("non-empty list, no match (other subject and ID)", func(t *testing.T) { m := setupStore(t, storageEngine.GetSQLDatabase()) - require.NoError(t, m.add(testServiceID, vpBob, "")) + require.NoError(t, m.add(testServiceID, vpBob, 0)) exists, err := m.exists(testServiceID, aliceDID.String(), vpAlice.ID.String()) assert.NoError(t, err) assert.False(t, exists) }) t.Run("non-empty list, no match (other list)", func(t *testing.T) { m := setupStore(t, storageEngine.GetSQLDatabase()) - require.NoError(t, m.add(testServiceID, vpAlice, "")) + require.NoError(t, m.add(testServiceID, vpAlice, 0)) exists, err := m.exists("other", aliceDID.String(), vpAlice.ID.String()) assert.NoError(t, err) assert.False(t, exists) }) t.Run("non-empty list, match", func(t *testing.T) { m := setupStore(t, storageEngine.GetSQLDatabase()) - require.NoError(t, m.add(testServiceID, vpAlice, "")) + require.NoError(t, m.add(testServiceID, vpAlice, 0)) exists, err := m.exists(testServiceID, aliceDID.String(), vpAlice.ID.String()) assert.NoError(t, err) assert.True(t, exists) @@ -70,16 +70,27 @@ func Test_sqlStore_add(t *testing.T) { t.Run("no credentials in presentation", func(t *testing.T) { m := setupStore(t, storageEngine.GetSQLDatabase()) - err := m.add(testServiceID, createPresentation(aliceDID), "") + err := m.add(testServiceID, createPresentation(aliceDID), 0) assert.NoError(t, err) }) + t.Run("passing timestamp updates last_timestamp", func(t *testing.T) { + m := setupStore(t, storageEngine.GetSQLDatabase()) + err := m.add(testServiceID, createPresentation(aliceDID), 1) + require.NoError(t, err) + + timestamp, err := m.getTimestamp(testServiceID) + + require.NoError(t, err) + assert.Equal(t, 1, timestamp) + }) + t.Run("replaces previous presentation of same subject", func(t *testing.T) { m := setupStore(t, storageEngine.GetSQLDatabase()) secondVP := createPresentation(aliceDID, vcAlice) - require.NoError(t, m.add(testServiceID, vpAlice, "")) - require.NoError(t, m.add(testServiceID, secondVP, "")) + require.NoError(t, m.add(testServiceID, vpAlice, 0)) + require.NoError(t, m.add(testServiceID, secondVP, 0)) // First VP should not exist exists, err := m.exists(testServiceID, aliceDID.String(), vpAlice.ID.String()) @@ -97,53 +108,47 @@ func Test_sqlStore_get(t *testing.T) { storageEngine := storage.NewTestStorageEngine(t) require.NoError(t, storageEngine.Start()) - t.Run("empty list, empty tag", func(t *testing.T) { + t.Run("empty list, 0 timestamp", func(t *testing.T) { m := setupStore(t, storageEngine.GetSQLDatabase()) - presentations, tag, err := m.get(testServiceID, nil) + presentations, timestamp, err := m.get(testServiceID, 0) assert.NoError(t, err) assert.Empty(t, presentations) - expectedTag := tagForTimestamp(t, m, testServiceID, 0) - assert.Equal(t, expectedTag, *tag) + assert.Equal(t, 0, timestamp) }) - t.Run("1 entry, empty tag", func(t *testing.T) { + t.Run("1 entry, 0 timestamp", func(t *testing.T) { m := setupStore(t, storageEngine.GetSQLDatabase()) - require.NoError(t, m.add(testServiceID, vpAlice, "")) - presentations, tag, err := m.get(testServiceID, nil) + require.NoError(t, m.add(testServiceID, vpAlice, 0)) + presentations, timestamp, err := m.get(testServiceID, 0) assert.NoError(t, err) - assert.Equal(t, []vc.VerifiablePresentation{vpAlice}, presentations) - expectedTag := tagForTimestamp(t, m, testServiceID, 1) - assert.Equal(t, expectedTag, *tag) + assert.Equal(t, map[string]vc.VerifiablePresentation{"1": vpAlice}, presentations) + assert.Equal(t, 1, timestamp) }) - t.Run("2 entries, empty tag", func(t *testing.T) { + t.Run("2 entries, 0 timestamp", func(t *testing.T) { m := setupStore(t, storageEngine.GetSQLDatabase()) - require.NoError(t, m.add(testServiceID, vpAlice, "")) - require.NoError(t, m.add(testServiceID, vpBob, "")) - presentations, tag, err := m.get(testServiceID, nil) + require.NoError(t, m.add(testServiceID, vpAlice, 0)) + require.NoError(t, m.add(testServiceID, vpBob, 0)) + presentations, timestamp, err := m.get(testServiceID, 0) assert.NoError(t, err) - assert.Equal(t, []vc.VerifiablePresentation{vpAlice, vpBob}, presentations) - expectedTS := tagForTimestamp(t, m, testServiceID, 2) - assert.Equal(t, expectedTS, *tag) + assert.Equal(t, map[string]vc.VerifiablePresentation{"1": vpAlice, "2": vpBob}, presentations) + assert.Equal(t, 2, timestamp) }) t.Run("2 entries, start after first", func(t *testing.T) { m := setupStore(t, storageEngine.GetSQLDatabase()) - require.NoError(t, m.add(testServiceID, vpAlice, "")) - require.NoError(t, m.add(testServiceID, vpBob, "")) - ts := tagForTimestamp(t, m, testServiceID, 1) - presentations, tag, err := m.get(testServiceID, &ts) + require.NoError(t, m.add(testServiceID, vpAlice, 0)) + require.NoError(t, m.add(testServiceID, vpBob, 0)) + presentations, timestamp, err := m.get(testServiceID, 1) assert.NoError(t, err) - assert.Equal(t, []vc.VerifiablePresentation{vpBob}, presentations) - expectedTS := tagForTimestamp(t, m, testServiceID, 2) - assert.Equal(t, expectedTS, *tag) + assert.Equal(t, map[string]vc.VerifiablePresentation{"2": vpBob}, presentations) + assert.Equal(t, 2, timestamp) }) t.Run("2 entries, start at end", func(t *testing.T) { m := setupStore(t, storageEngine.GetSQLDatabase()) - require.NoError(t, m.add(testServiceID, vpAlice, "")) - require.NoError(t, m.add(testServiceID, vpBob, "")) - expectedTag := tagForTimestamp(t, m, testServiceID, 2) - presentations, tag, err := m.get(testServiceID, &expectedTag) + require.NoError(t, m.add(testServiceID, vpAlice, 0)) + require.NoError(t, m.add(testServiceID, vpBob, 0)) + presentations, timestamp, err := m.get(testServiceID, 2) assert.NoError(t, err) - assert.Equal(t, []vc.VerifiablePresentation{}, presentations) - assert.Equal(t, expectedTag, *tag) + assert.Equal(t, map[string]vc.VerifiablePresentation{}, presentations) + assert.Equal(t, 2, timestamp) }) t.Run("concurrency", func(t *testing.T) { c := setupStore(t, storageEngine.GetSQLDatabase()) @@ -152,7 +157,7 @@ func Test_sqlStore_get(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - err := c.add(testServiceID, createPresentation(aliceDID, vcAlice), "") + err := c.add(testServiceID, createPresentation(aliceDID, vcAlice), 0) require.NoError(t, err) }() } @@ -178,7 +183,7 @@ func Test_sqlStore_search(t *testing.T) { vps := []vc.VerifiablePresentation{vpAlice} c := setupStore(t, storageEngine.GetSQLDatabase()) for _, vp := range vps { - err := c.add(testServiceID, vp, "") + err := c.add(testServiceID, vp, 0) require.NoError(t, err) } @@ -193,7 +198,7 @@ func Test_sqlStore_search(t *testing.T) { vps := []vc.VerifiablePresentation{vpAlice, vpBob} c := setupStore(t, storageEngine.GetSQLDatabase()) for _, vp := range vps { - err := c.add(testServiceID, vp, "") + err := c.add(testServiceID, vp, 0) require.NoError(t, err) } actualVPs, err := c.search(testServiceID, map[string]string{ @@ -277,7 +282,7 @@ func Test_sqlStore_getPresentationRefreshTime(t *testing.T) { func setupStore(t *testing.T, db *gorm.DB) *sqlStore { resetStore(t, db) defs := testDefinitions() - store, err := newSQLStore(db, defs, defs) + store, err := newSQLStore(db, defs) require.NoError(t, err) return store } @@ -289,20 +294,3 @@ func resetStore(t *testing.T, db *gorm.DB) { require.NoError(t, db.Exec("DELETE FROM "+tableName).Error) } } - -func Test_generateSeed(t *testing.T) { - for i := 0; i < 100; i++ { - seed := generatePrefix() - assert.Len(t, seed, 5) - for _, char := range seed { - assert.True(t, char >= 'A' && char <= 'Z') - } - } -} - -func tagForTimestamp(t *testing.T, store *sqlStore, serviceID string, ts int) Tag { - t.Helper() - var service serviceRecord - require.NoError(t, store.db.Find(&service, "id = ?", serviceID).Error) - return Timestamp(ts).Tag(service.TagPrefix) -} diff --git a/docs/_static/discovery/server.yaml b/docs/_static/discovery/server.yaml index 03d4a95410..4bf8547842 100644 --- a/docs/_static/discovery/server.yaml +++ b/docs/_static/discovery/server.yaml @@ -16,9 +16,9 @@ paths: get: summary: Retrieves the presentations of a Discovery Service. description: | - An API provided by the discovery server to retrieve the presentations of a Discovery Service, starting at the given tag. - The client should provide the tag it was returned in the last response. - If no tag is given, it will return all presentations. + An API provided by the discovery server to retrieve the presentations of a Discovery Service, starting from the given timestamp. + The client should provide the timestamp it was returned in the last response. + If no timestamp is given, it will return all presentations. error returns: * 404 - unknown service ID @@ -26,13 +26,13 @@ paths: tags: - discovery parameters: - - name: tag + - name: timestamp in: query schema: - type: string + type: integer responses: "200": - description: Presentations are returned, alongside the tag which should be provided at the next query. + description: Presentations are returned, alongside the timestamp which should be provided at the next query. content: application/json: schema: @@ -76,14 +76,16 @@ components: PresentationsResponse: type: object required: - - tag + - timestamp - entries properties: - tag: - type: string + timestamp: + description: highest timestamp of the returned presentations, should be used as the timestamp for the next query + type: integer entries: - type: array - items: + type: object + description: A map of timestamp (as string) to presentation. + additionalProperties: $ref: "#/components/schemas/VerifiablePresentation" SearchResult: type: object diff --git a/storage/sql_migrations/002_discoveryservice.sql b/storage/sql_migrations/002_discoveryservice.sql index d06b52c7f5..2338f8d452 100644 --- a/storage/sql_migrations/002_discoveryservice.sql +++ b/storage/sql_migrations/002_discoveryservice.sql @@ -1,15 +1,12 @@ -- +goose ENVSUB ON -- +goose Up --- discovery contains the known discovery services and the associated tags. +-- discovery contains the known discovery services and the associated timestamp. create table discovery_service ( -- id is the unique identifier for the service. It comes from the service definition. - id varchar(200) not null primary key, - -- tag is the latest tag pointing to the last presentation registered on the service. - last_tag varchar(40) null, - -- tag_prefix is used to prefix the tag of the presentations of the service. - -- It is only populated if the node is server for this service. - tag_prefix varchar(5) null + id varchar(200) not null primary key, + -- last_lamport_timestamp is the latest lamport_timestamp pointing to the last presentation registered on the service. + last_lamport_timestamp integer not null ); -- discovery_presentation contains the presentations of the discovery services @@ -17,9 +14,7 @@ create table discovery_presentation ( id varchar(36) not null primary key, service_id varchar(200) not null, - -- lamport_timestamp is the lamport clock of the presentation, converted to a tag and then returned to the client. - -- It is only populated if the node is server for this service. - lamport_timestamp integer null, + lamport_timestamp integer not null, credential_subject_id varchar(370) not null, presentation_id varchar(415) not null, presentation_raw $TEXT_TYPE not null,