diff --git a/auth/api/iam/openid4vp.go b/auth/api/iam/openid4vp.go index b778211e1d..28dd93cb86 100644 --- a/auth/api/iam/openid4vp.go +++ b/auth/api/iam/openid4vp.go @@ -280,9 +280,7 @@ func (r Wrapper) handleAuthorizeRequestFromVerifier(ctx context.Context, tenantD return r.sendAndHandleDirectPostError(ctx, oauth.OAuth2Error{Code: oauth.InvalidRequest, Description: "missing nonce parameter"}, responseURI, state) } - // TODO: Create session if it does not exist (use client state to get original Authorization Code request)? - // Although it would be quite weird (maybe it expired). - userSession, err := usersession.Get(ctx, tenantDID) + userSession, err := usersession.Get(ctx) if userSession == nil { return nil, oauth.OAuth2Error{Code: oauth.InvalidRequest, InternalError: err, Description: "no user session found"} } diff --git a/auth/api/iam/user.go b/auth/api/iam/user.go index ae10e085e7..91454c9401 100644 --- a/auth/api/iam/user.go +++ b/auth/api/iam/user.go @@ -82,8 +82,7 @@ func (r Wrapper) handleUserLanding(echoCtx echo.Context) error { } // Make sure there's a user session, loaded with EmployeeCredential - tenantDID, _ := did.ParseDID(accessTokenRequest.Did) // can't fail, since the request was created earlier (thus, validated) - userSession, err := usersession.Get(echoCtx.Request().Context(), *tenantDID) + userSession, err := usersession.Get(echoCtx.Request().Context()) if err != nil { return err } @@ -163,18 +162,18 @@ func (r Wrapper) provisionUserSession(ctx context.Context, session *usersession. return session.Save() } -func (r Wrapper) issueEmployeeCredential(ctx context.Context, data usersession.Data, userDetails UserDetails) (*vc.VerifiableCredential, error) { +func (r Wrapper) issueEmployeeCredential(ctx context.Context, session usersession.Data, userDetails UserDetails) (*vc.VerifiableCredential, error) { issuanceDate := time.Now() - expirationDate := data.ExpiresAt + expirationDate := session.ExpiresAt template := vc.VerifiableCredential{ Context: []ssi.URI{credential.NutsV1ContextURI}, Type: []ssi.URI{ssi.MustParseURI("EmployeeCredential")}, - Issuer: data.TenantDID.URI(), + Issuer: session.TenantDID.URI(), IssuanceDate: issuanceDate, ExpirationDate: &expirationDate, CredentialSubject: []interface{}{ map[string]string{ - "id": data.Wallet.DID.String(), + "id": session.Wallet.DID.String(), "identifier": userDetails.Id, "name": userDetails.Name, "roleName": userDetails.Role, diff --git a/auth/api/iam/usersession/user_session.go b/auth/api/iam/usersession/user_session.go index 0eff07f3de..a79d1f9c94 100644 --- a/auth/api/iam/usersession/user_session.go +++ b/auth/api/iam/usersession/user_session.go @@ -38,15 +38,15 @@ import ( var userSessionContextKey = struct{}{} // userSessionCookieName is the name of the cookie used to store the user session. -// It uses the __Host prefix, that instructs the user agent to treat it as a secure cookie: +// It uses the __Secure prefix, that instructs the user agent to treat it as a secure cookie: // - Must be set with the Secure attribute // - Must be set from an HTTPS uri -// - Must not contain a Domain attribute -// - Must contain a Path attribute +// Note that earlier, we used the Host cookie prefix, but that doesn't work in a multi-tenant environment, +// since then the Path attribute (used for multi-tenancy) can't be used. // Also see: // - https://owasp.org/www-project-web-security-testing-guide/latest/4-Web_Application_Security_Testing/06-Session_Management_Testing/02-Testing_for_Cookies_Attributes // - https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies -const userSessionCookieName = "__Host-SID" +const userSessionCookieName = "__Secure-SID" // Middleware is Echo middleware that ensures a user session is available in the request context (unless skipped). // If no session is available, a new session is created. @@ -77,15 +77,15 @@ func (u Middleware) Handle(next echo.HandlerFunc) echo.HandlerFunc { return fmt.Errorf("invalid tenant DID: %w", err) } - sessionID, session, err := u.loadUserSession(echoCtx, *tenantDID) + sessionID, sessionData, err := u.loadUserSession(echoCtx, *tenantDID) if err != nil { // Should only really occur in exceptional circumstances (e.g. cookie survived after intended max age). log.Logger().WithError(err).Info("Invalid user session, a new session will be created") } - if session == nil { - session, err = createUserSession(*tenantDID, u.TimeOut) + if sessionData == nil { + sessionData, err = createUserSession(*tenantDID, u.TimeOut) sessionID = crypto.GenerateNonce() - if err := u.Store.Put(sessionID, session); err != nil { + if err := u.Store.Put(sessionID, sessionData); err != nil { return err } if err != nil { @@ -94,11 +94,11 @@ func (u Middleware) Handle(next echo.HandlerFunc) echo.HandlerFunc { // By scoping the cookie to a tenant (DID)-specific path, the user can have a session per tenant DID on the same domain. echoCtx.SetCookie(u.createUserSessionCookie(sessionID, u.CookiePath(*tenantDID))) } - session.Save = func() error { - return u.Store.Put(sessionID, session) + sessionData.Save = func() error { + return u.Store.Put(sessionID, sessionData) } // Session data is put in request context for access by API handlers - echoCtx.SetRequest(echoCtx.Request().WithContext(context.WithValue(echoCtx.Request().Context(), userSessionContextKey, session))) + echoCtx.SetRequest(echoCtx.Request().WithContext(context.WithValue(echoCtx.Request().Context(), userSessionContextKey, sessionData))) return next(echoCtx) } @@ -128,9 +128,7 @@ func (u Middleware) loadUserSession(cookies CookieReader, tenantDID did.DID) (st // but this adds less complexity. return "", nil, errors.New("expired session") } - // Note that the session itself does not have an expiration field: - // it depends on the session store to clean up when it expires. - if !session.TenantDID.Equals(tenantDID) && !session.Wallet.DID.Equals(tenantDID) { + if !session.TenantDID.Equals(tenantDID) { return "", nil, fmt.Errorf("session belongs to another tenant (%s)", session.TenantDID) } return sessionID, session, nil @@ -170,15 +168,12 @@ func (u Middleware) createUserSessionCookie(sessionID string, path string) *http } // Get retrieves the user session from the request context. -// If the user session is not found, or belongs to another tenant, an error is returned. -func Get(ctx context.Context, expectedTenantDID did.DID) (*Data, error) { +// If the user session is not found, an error is returned. +func Get(ctx context.Context) (*Data, error) { result, ok := ctx.Value(userSessionContextKey).(*Data) if !ok { return nil, errors.New("no user session found") } - if result.TenantDID.String() != expectedTenantDID.String() { - return nil, errors.New("user session belongs to another tenant") - } return result, nil } diff --git a/auth/api/iam/usersession/user_session_test.go b/auth/api/iam/usersession/user_session_test.go index ac5ecb7c8f..fc6d9cae6f 100644 --- a/auth/api/iam/usersession/user_session_test.go +++ b/auth/api/iam/usersession/user_session_test.go @@ -21,6 +21,7 @@ package usersession import ( "github.com/labstack/echo/v4" "github.com/nuts-foundation/go-did/did" + "github.com/nuts-foundation/go-did/vc" "github.com/nuts-foundation/nuts-node/storage" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -64,7 +65,7 @@ func TestMiddleware_Handle(t *testing.T) { var capturedSession *Data err := instance.Handle(func(c echo.Context) error { var err error - capturedSession, err = Get(c.Request().Context(), tenantDID) + capturedSession, err = Get(c.Request().Context()) return err })(echoContext) @@ -91,9 +92,9 @@ func TestMiddleware_Handle(t *testing.T) { var capturedSession *Data err := instance.Handle(func(c echo.Context) error { - var err error - capturedSession, err = Get(c.Request().Context(), tenantDID) - return err + capturedSession, _ = Get(c.Request().Context()) + capturedSession.Wallet.Credentials = append(capturedSession.Wallet.Credentials, vc.VerifiableCredential{}) + return capturedSession.Save() })(echoContext) assert.NoError(t, err) @@ -162,7 +163,7 @@ func TestMiddleware_Handle(t *testing.T) { var capturedSession *Data err := instance.Handle(func(c echo.Context) error { var err error - capturedSession, err = Get(c.Request().Context(), tenantDID) + capturedSession, err = Get(c.Request().Context()) return err })(echoContext)