Skip to content

Commit

Permalink
fix multi-tenancy
Browse files Browse the repository at this point in the history
  • Loading branch information
reinkrul committed May 25, 2024
1 parent cc818c8 commit 1fef2a8
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 33 deletions.
4 changes: 1 addition & 3 deletions auth/api/iam/openid4vp.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,9 +280,7 @@ func (r Wrapper) handleAuthorizeRequestFromVerifier(ctx context.Context, tenantD
return r.sendAndHandleDirectPostError(ctx, oauth.OAuth2Error{Code: oauth.InvalidRequest, Description: "missing nonce parameter"}, responseURI, state)
}

// TODO: Create session if it does not exist (use client state to get original Authorization Code request)?
// Although it would be quite weird (maybe it expired).
userSession, err := usersession.Get(ctx, tenantDID)
userSession, err := usersession.Get(ctx)
if userSession == nil {
return nil, oauth.OAuth2Error{Code: oauth.InvalidRequest, InternalError: err, Description: "no user session found"}
}
Expand Down
11 changes: 5 additions & 6 deletions auth/api/iam/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,7 @@ func (r Wrapper) handleUserLanding(echoCtx echo.Context) error {
}

// Make sure there's a user session, loaded with EmployeeCredential
tenantDID, _ := did.ParseDID(accessTokenRequest.Did) // can't fail, since the request was created earlier (thus, validated)
userSession, err := usersession.Get(echoCtx.Request().Context(), *tenantDID)
userSession, err := usersession.Get(echoCtx.Request().Context())
if err != nil {
return err
}
Expand Down Expand Up @@ -163,18 +162,18 @@ func (r Wrapper) provisionUserSession(ctx context.Context, session *usersession.
return session.Save()
}

func (r Wrapper) issueEmployeeCredential(ctx context.Context, data usersession.Data, userDetails UserDetails) (*vc.VerifiableCredential, error) {
func (r Wrapper) issueEmployeeCredential(ctx context.Context, session usersession.Data, userDetails UserDetails) (*vc.VerifiableCredential, error) {
issuanceDate := time.Now()
expirationDate := data.ExpiresAt
expirationDate := session.ExpiresAt
template := vc.VerifiableCredential{
Context: []ssi.URI{credential.NutsV1ContextURI},
Type: []ssi.URI{ssi.MustParseURI("EmployeeCredential")},
Issuer: data.TenantDID.URI(),
Issuer: session.TenantDID.URI(),
IssuanceDate: issuanceDate,
ExpirationDate: &expirationDate,
CredentialSubject: []interface{}{
map[string]string{
"id": data.Wallet.DID.String(),
"id": session.Wallet.DID.String(),
"identifier": userDetails.Id,
"name": userDetails.Name,
"roleName": userDetails.Role,
Expand Down
33 changes: 14 additions & 19 deletions auth/api/iam/usersession/user_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ import (
var userSessionContextKey = struct{}{}

// userSessionCookieName is the name of the cookie used to store the user session.
// It uses the __Host prefix, that instructs the user agent to treat it as a secure cookie:
// It uses the __Secure prefix, that instructs the user agent to treat it as a secure cookie:
// - Must be set with the Secure attribute
// - Must be set from an HTTPS uri
// - Must not contain a Domain attribute
// - Must contain a Path attribute
// Note that earlier, we used the Host cookie prefix, but that doesn't work in a multi-tenant environment,
// since then the Path attribute (used for multi-tenancy) can't be used.
// Also see:
// - https://owasp.org/www-project-web-security-testing-guide/latest/4-Web_Application_Security_Testing/06-Session_Management_Testing/02-Testing_for_Cookies_Attributes
// - https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies
const userSessionCookieName = "__Host-SID"
const userSessionCookieName = "__Secure-SID"

// Middleware is Echo middleware that ensures a user session is available in the request context (unless skipped).
// If no session is available, a new session is created.
Expand Down Expand Up @@ -77,15 +77,15 @@ func (u Middleware) Handle(next echo.HandlerFunc) echo.HandlerFunc {
return fmt.Errorf("invalid tenant DID: %w", err)
}

sessionID, session, err := u.loadUserSession(echoCtx, *tenantDID)
sessionID, sessionData, err := u.loadUserSession(echoCtx, *tenantDID)
if err != nil {
// Should only really occur in exceptional circumstances (e.g. cookie survived after intended max age).
log.Logger().WithError(err).Info("Invalid user session, a new session will be created")
}
if session == nil {
session, err = createUserSession(*tenantDID, u.TimeOut)
if sessionData == nil {
sessionData, err = createUserSession(*tenantDID, u.TimeOut)
sessionID = crypto.GenerateNonce()
if err := u.Store.Put(sessionID, session); err != nil {
if err := u.Store.Put(sessionID, sessionData); err != nil {
return err
}
if err != nil {
Expand All @@ -94,11 +94,11 @@ func (u Middleware) Handle(next echo.HandlerFunc) echo.HandlerFunc {
// By scoping the cookie to a tenant (DID)-specific path, the user can have a session per tenant DID on the same domain.
echoCtx.SetCookie(u.createUserSessionCookie(sessionID, u.CookiePath(*tenantDID)))
}
session.Save = func() error {
return u.Store.Put(sessionID, session)
sessionData.Save = func() error {
return u.Store.Put(sessionID, sessionData)
}
// Session data is put in request context for access by API handlers
echoCtx.SetRequest(echoCtx.Request().WithContext(context.WithValue(echoCtx.Request().Context(), userSessionContextKey, session)))
echoCtx.SetRequest(echoCtx.Request().WithContext(context.WithValue(echoCtx.Request().Context(), userSessionContextKey, sessionData)))

return next(echoCtx)
}
Expand Down Expand Up @@ -128,9 +128,7 @@ func (u Middleware) loadUserSession(cookies CookieReader, tenantDID did.DID) (st
// but this adds less complexity.
return "", nil, errors.New("expired session")
}
// Note that the session itself does not have an expiration field:
// it depends on the session store to clean up when it expires.
if !session.TenantDID.Equals(tenantDID) && !session.Wallet.DID.Equals(tenantDID) {
if !session.TenantDID.Equals(tenantDID) {
return "", nil, fmt.Errorf("session belongs to another tenant (%s)", session.TenantDID)
}
return sessionID, session, nil
Expand Down Expand Up @@ -170,15 +168,12 @@ func (u Middleware) createUserSessionCookie(sessionID string, path string) *http
}

// Get retrieves the user session from the request context.
// If the user session is not found, or belongs to another tenant, an error is returned.
func Get(ctx context.Context, expectedTenantDID did.DID) (*Data, error) {
// If the user session is not found, an error is returned.
func Get(ctx context.Context) (*Data, error) {
result, ok := ctx.Value(userSessionContextKey).(*Data)
if !ok {
return nil, errors.New("no user session found")
}
if result.TenantDID.String() != expectedTenantDID.String() {
return nil, errors.New("user session belongs to another tenant")
}
return result, nil
}

Expand Down
11 changes: 6 additions & 5 deletions auth/api/iam/usersession/user_session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package usersession
import (
"github.com/labstack/echo/v4"
"github.com/nuts-foundation/go-did/did"
"github.com/nuts-foundation/go-did/vc"
"github.com/nuts-foundation/nuts-node/storage"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -64,7 +65,7 @@ func TestMiddleware_Handle(t *testing.T) {
var capturedSession *Data
err := instance.Handle(func(c echo.Context) error {
var err error
capturedSession, err = Get(c.Request().Context(), tenantDID)
capturedSession, err = Get(c.Request().Context())
return err
})(echoContext)

Expand All @@ -91,9 +92,9 @@ func TestMiddleware_Handle(t *testing.T) {

var capturedSession *Data
err := instance.Handle(func(c echo.Context) error {
var err error
capturedSession, err = Get(c.Request().Context(), tenantDID)
return err
capturedSession, _ = Get(c.Request().Context())
capturedSession.Wallet.Credentials = append(capturedSession.Wallet.Credentials, vc.VerifiableCredential{})
return capturedSession.Save()
})(echoContext)

assert.NoError(t, err)
Expand Down Expand Up @@ -162,7 +163,7 @@ func TestMiddleware_Handle(t *testing.T) {
var capturedSession *Data
err := instance.Handle(func(c echo.Context) error {
var err error
capturedSession, err = Get(c.Request().Context(), tenantDID)
capturedSession, err = Get(c.Request().Context())
return err
})(echoContext)

Expand Down

0 comments on commit 1fef2a8

Please sign in to comment.