diff --git a/services/proxy/.mockery.yaml b/services/proxy/.mockery.yaml index a490457301..d3ae0a3817 100644 --- a/services/proxy/.mockery.yaml +++ b/services/proxy/.mockery.yaml @@ -12,3 +12,8 @@ packages: github.com/opencloud-eu/opencloud/services/proxy/pkg/userroles: interfaces: UserRoleAssigner: {} + go-micro.dev/v4/store: + config: + dir: pkg/staticroutes/internal/backchannellogout/mocks + interfaces: + Store: {} diff --git a/services/proxy/pkg/command/server.go b/services/proxy/pkg/command/server.go index 25c1d045dc..9209ec30fb 100644 --- a/services/proxy/pkg/command/server.go +++ b/services/proxy/pkg/command/server.go @@ -10,6 +10,7 @@ import ( gateway "github.com/cs3org/go-cs3apis/cs3/gateway/v1beta1" "github.com/justinas/alice" + "github.com/opencloud-eu/opencloud/pkg/config/configlog" "github.com/opencloud-eu/opencloud/pkg/generators" "github.com/opencloud-eu/opencloud/pkg/log" @@ -72,6 +73,7 @@ func Server(cfg *config.Config) *cobra.Command { microstore.Nodes(cfg.PreSignedURL.SigningKeys.Nodes...), microstore.Database("proxy"), microstore.Table("signing-keys"), + store.DisablePersistence(cfg.PreSignedURL.SigningKeys.DisablePersistence), store.Authentication(cfg.PreSignedURL.SigningKeys.AuthUsername, cfg.PreSignedURL.SigningKeys.AuthPassword), ) diff --git a/services/proxy/pkg/middleware/oidc_auth.go b/services/proxy/pkg/middleware/oidc_auth.go index 18e0ef3344..00822cdc99 100644 --- a/services/proxy/pkg/middleware/oidc_auth.go +++ b/services/proxy/pkg/middleware/oidc_auth.go @@ -8,13 +8,15 @@ import ( "strings" "time" - "github.com/opencloud-eu/opencloud/pkg/log" - "github.com/opencloud-eu/opencloud/pkg/oidc" "github.com/pkg/errors" "github.com/vmihailenco/msgpack/v5" - store "go-micro.dev/v4/store" + "go-micro.dev/v4/store" "golang.org/x/crypto/sha3" "golang.org/x/oauth2" + + "github.com/opencloud-eu/opencloud/pkg/log" + "github.com/opencloud-eu/opencloud/pkg/oidc" + "github.com/opencloud-eu/opencloud/services/proxy/pkg/staticroutes" ) const ( @@ -114,16 +116,25 @@ func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[stri m.Logger.Error().Err(err).Msg("failed to write to userinfo cache") } - if sid := aClaims.SessionID; sid != "" { - // reuse user cache for session id lookup - err = m.userInfoCache.Write(&store.Record{ - Key: sid, - Value: []byte(encodedHash), - Expiry: time.Until(expiration), - }) - if err != nil { - m.Logger.Error().Err(err).Msg("failed to write session lookup cache") - } + // fail if creating the storage key fails, + // it means there is no subject and no session. + // + // ok: {key: ".sessionId"} + // ok: {key: "subject."} + // ok: {key: "subject.sessionId"} + // fail: {key: "."} + subjectSessionKey, err := staticroutes.NewRecordKey(aClaims.Subject, aClaims.SessionID) + if err != nil { + m.Logger.Error().Err(err).Msg("failed to build subject.session") + return + } + + if err := m.userInfoCache.Write(&store.Record{ + Key: subjectSessionKey, + Value: []byte(encodedHash), + Expiry: time.Until(expiration), + }); err != nil { + m.Logger.Error().Err(err).Msg("failed to write session lookup cache") } } }() diff --git a/services/proxy/pkg/staticroutes/backchannellogout.go b/services/proxy/pkg/staticroutes/backchannellogout.go index 0c67a4d951..53375d63bb 100644 --- a/services/proxy/pkg/staticroutes/backchannellogout.go +++ b/services/proxy/pkg/staticroutes/backchannellogout.go @@ -6,17 +6,40 @@ import ( "net/http" "github.com/go-chi/render" - "github.com/opencloud-eu/opencloud/pkg/oidc" - "github.com/opencloud-eu/reva/v2/pkg/events" - "github.com/opencloud-eu/reva/v2/pkg/utils" "github.com/pkg/errors" "github.com/vmihailenco/msgpack/v5" microstore "go-micro.dev/v4/store" + + bcl "github.com/opencloud-eu/opencloud/services/proxy/pkg/staticroutes/internal/backchannellogout" + "github.com/opencloud-eu/reva/v2/pkg/events" + "github.com/opencloud-eu/reva/v2/pkg/utils" ) -// handle backchannel logout requests as per https://openid.net/specs/openid-connect-backchannel-1_0.html#BCRequest +// NewRecordKey converts the subject and session to a base64 encoded key +var NewRecordKey = bcl.NewKey + +// backchannelLogout handles backchannel logout requests from the identity provider and invalidates the related sessions in the cache +// spec: https://openid.net/specs/openid-connect-backchannel-1_0.html#BCRequest +// +// known side effects of backchannel logout in keycloak: +// +// - keyCloak "Sign out all active sessions" does not send a backchannel logout request, +// as the devs mention, this may lead to thousands of backchannel logout requests, +// therefore, they recommend a short token lifetime. +// https://github.com/keycloak/keycloak/issues/27342#issuecomment-2408461913 +// +// - keyCloak user self-service portal, "Sign out all devices" may not send a backchannel +// logout request for each session, it's not mentionex explicitly, +// but maybe the reason for that is the same as for "Sign out all active sessions" +// to prevent a flood of backchannel logout requests. +// +// - if the keycloak setting "Backchannel logout session required" is disabled (or the token has no session id), +// we resolve the session by the subject which can lead to multiple session records (subject.*), +// we then send a logout event (sse) to each connected client and delete our stored cache record (subject.session & claim). +// all sessions besides the one that triggered the backchannel logout continue to exist in the identity provider, +// so the user will not be fully logged out until all sessions are logged out or expired. +// this leads to the situation that web renders the logout view even if the instance is not fully logged out yet. func (s *StaticRouteHandler) backchannelLogout(w http.ResponseWriter, r *http.Request) { - // parse the application/x-www-form-urlencoded POST request logger := s.Logger.SubloggerWithRequestID(r.Context()) if err := r.ParseForm(); err != nil { logger.Warn().Err(err).Msg("ParseForm failed") @@ -27,45 +50,84 @@ func (s *StaticRouteHandler) backchannelLogout(w http.ResponseWriter, r *http.Re logoutToken, err := s.OidcClient.VerifyLogoutToken(r.Context(), r.PostFormValue("logout_token")) if err != nil { - logger.Warn().Err(err).Msg("VerifyLogoutToken failed") + msg := "failed to verify logout token" + logger.Warn().Err(err).Msg(msg) render.Status(r, http.StatusBadRequest) - render.JSON(w, r, jse{Error: "invalid_request", ErrorDescription: err.Error()}) + render.JSON(w, r, jse{Error: "invalid_request", ErrorDescription: msg}) + return + } + + lookupKey, err := bcl.NewKey(logoutToken.Subject, logoutToken.SessionId) + if err != nil { + msg := "failed to build key from logout token" + logger.Warn().Err(err).Msg(msg) + render.Status(r, http.StatusBadRequest) + render.JSON(w, r, jse{Error: "invalid_request", ErrorDescription: msg}) + return + } + + requestSubjectAndSession, err := bcl.NewSuSe(lookupKey) + if err != nil { + msg := "failed to build subjec.session from lookupKey" + logger.Error().Err(err).Msg(msg) + render.Status(r, http.StatusBadRequest) + render.JSON(w, r, jse{Error: "invalid_request", ErrorDescription: msg}) return } - records, err := s.UserInfoCache.Read(logoutToken.SessionId) - if errors.Is(err, microstore.ErrNotFound) || len(records) == 0 { + lookupRecords, err := bcl.GetLogoutRecords(requestSubjectAndSession, s.UserInfoCache) + if errors.Is(err, microstore.ErrNotFound) || len(lookupRecords) == 0 { render.Status(r, http.StatusOK) render.JSON(w, r, nil) return } if err != nil { - logger.Error().Err(err).Msg("Error reading userinfo cache") + msg := "failed to read userinfo cache" + logger.Error().Err(err).Msg(msg) render.Status(r, http.StatusBadRequest) - render.JSON(w, r, jse{Error: "invalid_request", ErrorDescription: err.Error()}) + render.JSON(w, r, jse{Error: "invalid_request", ErrorDescription: msg}) return } - for _, record := range records { - err := s.publishBackchannelLogoutEvent(r.Context(), record, logoutToken) + for _, record := range lookupRecords { + // the record key is in the format "subject.session" or ".session" + // the record value is the key of the record that contains the claim in its value + key, value := record.Key, string(record.Value) + + subjectSession, err := bcl.NewSuSe(key) + if err != nil { + // never leak any key-related information + logger.Warn().Err(err).Msgf("failed to parse key: %s", key) + continue + } + + session, err := subjectSession.Session() if err != nil { - s.Logger.Warn().Err(err).Msg("could not publish backchannel logout event") + logger.Warn().Err(err).Msgf("failed to read session for: %s", key) + continue + } + + if err := s.publishBackchannelLogoutEvent(r.Context(), session, value); err != nil { + s.Logger.Warn().Err(err).Msgf("failed to publish backchannel logout event for: %s", key) + continue } - err = s.UserInfoCache.Delete(string(record.Value)) + + err = s.UserInfoCache.Delete(value) if err != nil && !errors.Is(err, microstore.ErrNotFound) { - // Spec requires us to return a 400 BadRequest when the session could not be destroyed - logger.Err(err).Msg("could not delete user info from cache") + // we have to return a 400 BadRequest when we fail to delete the session + // https://openid.net/specs/openid-connect-backchannel-1_0.html#rfc.section.2.8 + msg := "failed to delete record" + s.Logger.Warn().Err(err).Msgf("%s for: %s", msg, key) render.Status(r, http.StatusBadRequest) - render.JSON(w, r, jse{Error: "invalid_request", ErrorDescription: err.Error()}) + render.JSON(w, r, jse{Error: "invalid_request", ErrorDescription: msg}) return } - logger.Debug().Msg("Deleted userinfo from cache") - } - // we can ignore errors when cleaning up the lookup table - err = s.UserInfoCache.Delete(logoutToken.SessionId) - if err != nil { - logger.Debug().Err(err).Msg("Failed to cleanup sessionid lookup entry") + // we can ignore errors when deleting the lookup record + err = s.UserInfoCache.Delete(key) + if err != nil { + logger.Debug().Err(err).Msgf("failed to delete record for: %s", key) + } } render.Status(r, http.StatusOK) @@ -73,41 +135,42 @@ func (s *StaticRouteHandler) backchannelLogout(w http.ResponseWriter, r *http.Re } // publishBackchannelLogoutEvent publishes a backchannel logout event when the callback revived from the identity provider -func (s StaticRouteHandler) publishBackchannelLogoutEvent(ctx context.Context, record *microstore.Record, logoutToken *oidc.LogoutToken) error { +func (s *StaticRouteHandler) publishBackchannelLogoutEvent(ctx context.Context, sessionId, claimKey string) error { if s.EventsPublisher == nil { - return fmt.Errorf("the events publisher is not set") - } - urecords, err := s.UserInfoCache.Read(string(record.Value)) - if err != nil { - return fmt.Errorf("reading userinfo cache: %w", err) + return errors.New("events publisher not set") } - if len(urecords) == 0 { - return fmt.Errorf("userinfo not found") + + claimRecords, err := s.UserInfoCache.Read(claimKey) + switch { + case err != nil: + return fmt.Errorf("failed to read userinfo cache: %w", err) + case len(claimRecords) == 0: + return fmt.Errorf("no claim found for key: %s", claimKey) } var claims map[string]interface{} - if err = msgpack.Unmarshal(urecords[0].Value, &claims); err != nil { - return fmt.Errorf("could not unmarshal userinfo: %w", err) + if err = msgpack.Unmarshal(claimRecords[0].Value, &claims); err != nil { + return fmt.Errorf("failed to unmarshal claims: %w", err) } oidcClaim, ok := claims[s.Config.UserOIDCClaim].(string) if !ok { - return fmt.Errorf("could not get claim %w", err) + return fmt.Errorf("failed to get claim %w", err) } user, _, err := s.UserProvider.GetUserByClaims(ctx, s.Config.UserCS3Claim, oidcClaim) if err != nil || user.GetId() == nil { - return fmt.Errorf("could not get user by claims: %w", err) + return fmt.Errorf("failed to get user by claims: %w", err) } e := events.BackchannelLogout{ Executant: user.GetId(), - SessionId: logoutToken.SessionId, + SessionId: sessionId, Timestamp: utils.TSNow(), } if err := events.Publish(ctx, s.EventsPublisher, e); err != nil { - return fmt.Errorf("could not publish user created event %w", err) + return fmt.Errorf("failed to publish user logout event %w", err) } return nil } diff --git a/services/proxy/pkg/staticroutes/internal/backchannellogout/backchannellogout.go b/services/proxy/pkg/staticroutes/internal/backchannellogout/backchannellogout.go new file mode 100644 index 0000000000..86ee00556b --- /dev/null +++ b/services/proxy/pkg/staticroutes/internal/backchannellogout/backchannellogout.go @@ -0,0 +1,188 @@ +// package backchannellogout provides functions to classify and lookup +// backchannel logout records from the cache store. + +package backchannellogout + +import ( + "encoding/base64" + "errors" + "strings" + + microstore "go-micro.dev/v4/store" +) + +// keyEncoding is the base64 encoding used for session and subject keys +var keyEncoding = base64.URLEncoding + +// ErrInvalidKey indicates that the provided key does not conform to the expected format. +var ErrInvalidKey = errors.New("invalid key format") + +// NewKey converts the subject and session to a base64 encoded key +func NewKey(subject, session string) (string, error) { + subjectSession := strings.Join([]string{ + keyEncoding.EncodeToString([]byte(subject)), + keyEncoding.EncodeToString([]byte(session)), + }, ".") + + if subjectSession == "." { + return "", ErrInvalidKey + } + + return subjectSession, nil +} + +// ErrDecoding is returned when decoding fails +var ErrDecoding = errors.New("failed to decode") + +// SuSe 🦎 ;) is a struct that groups the subject and session together +// to prevent mix-ups for ('session, subject' || 'subject, session') +// return values. +type SuSe struct { + encodedSubject string + encodedSession string +} + +// Subject decodes and returns the subject or an error +func (suse SuSe) Subject() (string, error) { + subject, err := keyEncoding.DecodeString(suse.encodedSubject) + if err != nil { + return "", errors.Join(errors.New("failed to decode subject"), ErrDecoding, err) + } + + return string(subject), nil +} + +// Session decodes and returns the session or an error +func (suse SuSe) Session() (string, error) { + subject, err := keyEncoding.DecodeString(suse.encodedSession) + if err != nil { + return "", errors.Join(errors.New("failed to decode session"), ErrDecoding, err) + } + + return string(subject), nil +} + +// ErrInvalidSubjectOrSession is returned when the provided key does not match the expected key format +var ErrInvalidSubjectOrSession = errors.New("invalid subject or session") + +// NewSuSe parses the subject and session id from the given key and returns a SuSe struct +func NewSuSe(key string) (SuSe, error) { + suse := SuSe{} + switch keys := strings.Split(strings.Join(strings.Fields(key), ""), "."); { + // key: '.session' + case len(keys) == 2 && keys[0] == "" && keys[1] != "": + suse.encodedSession = keys[1] + // key: 'subject.' + case len(keys) == 2 && keys[0] != "" && keys[1] == "": + suse.encodedSubject = keys[0] + // key: 'subject.session' + case len(keys) == 2 && keys[0] != "" && keys[1] != "": + suse.encodedSubject = keys[0] + suse.encodedSession = keys[1] + // key: 'session' + case len(keys) == 1 && keys[0] != "": + suse.encodedSession = keys[0] + default: + return suse, ErrInvalidSubjectOrSession + } + + if _, err := suse.Subject(); err != nil { + return suse, errors.Join(ErrInvalidSubjectOrSession, err) + } + + if _, err := suse.Session(); err != nil { + return suse, errors.Join(ErrInvalidSubjectOrSession, err) + } + + return suse, nil +} + +// logoutMode defines the mode of backchannel logout, either by session or by subject +type logoutMode int + +const ( + // logoutModeUndefined is used when the logout mode cannot be determined + logoutModeUndefined logoutMode = iota + // logoutModeSubject is used when the logout mode is determined by the subject + logoutModeSubject + // logoutModeSession is used when the logout mode is determined by the session id + logoutModeSession +) + +// getLogoutMode determines the backchannel logout mode based on the presence of subject and session in the SuSe struct +func getLogoutMode(suse SuSe) logoutMode { + switch { + case suse.encodedSession == "" && suse.encodedSubject != "": + return logoutModeSubject + case suse.encodedSession != "": + return logoutModeSession + default: + return logoutModeUndefined + } +} + +// ErrSuspiciousCacheResult is returned when the cache result is suspicious +var ErrSuspiciousCacheResult = errors.New("suspicious cache result") + +// GetLogoutRecords retrieves the records from the user info cache based on the backchannel +// logout mode and the provided SuSe struct. +// it uses a seperator to prevent sufix and prefix exploration in the cache and checks +// if the retrieved records match the requested subject and or session id as well, to prevent false positives. +func GetLogoutRecords(suse SuSe, store microstore.Store) ([]*microstore.Record, error) { + // get subject.session mode + mode := getLogoutMode(suse) + + var key string + var opts []microstore.ReadOption + switch mode { + case logoutModeSubject: + // the dot at the end prevents prefix exploration in the cache, + // so only keys that start with 'subject.*' will be returned, but not 'sub*'. + key = suse.encodedSubject + "." + opts = append(opts, microstore.ReadPrefix()) + case logoutModeSession: + // the dot at the beginning prevents sufix exploration in the cache, + // so only keys that end with '*.session' will be returned, but not '*sion'. + key = "." + suse.encodedSession + opts = append(opts, microstore.ReadSuffix()) + default: + return nil, errors.Join(errors.New("cannot determine logout mode"), ErrSuspiciousCacheResult) + } + + // the go micro memory store requires a limit to work, why??? + records, err := store.Read(key, append(opts, microstore.ReadLimit(1000))...) + if err != nil { + return nil, err + } + + if len(records) == 0 { + return nil, microstore.ErrNotFound + } + + if mode == logoutModeSession && len(records) > 1 { + return nil, errors.Join(errors.New("multiple session records found"), ErrSuspiciousCacheResult) + } + + // double-check if the found records match the requested subject and or session id as well, + // to prevent false positives. + for _, record := range records { + recordSuSe, err := NewSuSe(record.Key) + if err != nil { + // never leak any key-related information + return nil, errors.Join(errors.New("failed to parse key"), ErrSuspiciousCacheResult, err) + } + + switch { + // in subject mode, the subject must match, but the session id can be different + case mode == logoutModeSubject && suse.encodedSubject == recordSuSe.encodedSubject: + continue + // in session mode, the session id must match, but the subject can be different + case mode == logoutModeSession && suse.encodedSession == recordSuSe.encodedSession: + continue + } + + return nil, errors.Join(errors.New("key does not match the requested subject or session"), ErrSuspiciousCacheResult) + } + + return records, nil +} diff --git a/services/proxy/pkg/staticroutes/internal/backchannellogout/backchannellogout_test.go b/services/proxy/pkg/staticroutes/internal/backchannellogout/backchannellogout_test.go new file mode 100644 index 0000000000..617bd6d9e0 --- /dev/null +++ b/services/proxy/pkg/staticroutes/internal/backchannellogout/backchannellogout_test.go @@ -0,0 +1,331 @@ +package backchannellogout + +import ( + "slices" + "strings" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "go-micro.dev/v4/store" + + "github.com/opencloud-eu/opencloud/services/proxy/pkg/staticroutes/internal/backchannellogout/mocks" +) + +func mustNewKey(t *testing.T, subject, session string) string { + key, err := NewKey(subject, session) + require.NoError(t, err) + return key +} + +func mustNewSuSe(t *testing.T, subject, session string) SuSe { + suse, err := NewSuSe(mustNewKey(t, subject, session)) + require.NoError(t, err) + return suse +} + +func TestNewKey(t *testing.T) { + tests := []struct { + name string + subject string + session string + wantKey string + wantErr error + }{ + { + name: "key variation: 'subject.session'", + subject: "subject", + session: "session", + wantKey: "c3ViamVjdA==.c2Vzc2lvbg==", + }, + { + name: "key variation: 'subject.'", + subject: "subject", + wantKey: "c3ViamVjdA==.", + }, + { + name: "key variation: '.session'", + session: "session", + wantKey: ".c2Vzc2lvbg==", + }, + { + name: "key variation: '.'", + wantErr: ErrInvalidKey, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + key, err := NewKey(tt.subject, tt.session) + require.ErrorIs(t, err, tt.wantErr) + require.Equal(t, tt.wantKey, key) + }) + } +} + +func TestNewSuSe(t *testing.T) { + tests := []struct { + name string + key string + wantSubject string + wantSession string + wantErr error + }{ + { + name: "key variation: '.session'", + key: mustNewKey(t, "", "session"), + wantSession: "session", + }, + { + name: "key variation: 'session'", + key: mustNewKey(t, "", "session"), + wantSession: "session", + }, + { + name: "key variation: 'subject.'", + key: mustNewKey(t, "subject", ""), + wantSubject: "subject", + }, + { + name: "key variation: 'subject.session'", + key: mustNewKey(t, "subject", "session"), + wantSubject: "subject", + wantSession: "session", + }, + { + name: "key variation: 'dot'", + key: ".", + wantErr: ErrInvalidSubjectOrSession, + }, + { + name: "key variation: 'empty'", + key: "", + wantErr: ErrInvalidSubjectOrSession, + }, + { + name: "key variation: string('subject.session')", + key: "subject.session", + wantErr: ErrInvalidSubjectOrSession, + }, + { + name: "key variation: string('subject.')", + key: "subject.", + wantErr: ErrInvalidSubjectOrSession, + }, + { + name: "key variation: string('.session')", + key: ".session", + wantErr: ErrInvalidSubjectOrSession, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + suSe, err := NewSuSe(tt.key) + require.ErrorIs(t, err, tt.wantErr) + + subject, _ := suSe.Subject() + require.Equal(t, tt.wantSubject, subject) + + session, _ := suSe.Session() + require.Equal(t, tt.wantSession, session) + }) + } +} + +func TestGetLogoutMode(t *testing.T) { + tests := []struct { + name string + suSe SuSe + want logoutMode + }{ + { + name: "key variation: '.session'", + suSe: mustNewSuSe(t, "", "session"), + want: logoutModeSession, + }, + { + name: "key variation: 'subject.session'", + suSe: mustNewSuSe(t, "subject", "session"), + want: logoutModeSession, + }, + { + name: "key variation: 'subject.'", + suSe: mustNewSuSe(t, "subject", ""), + want: logoutModeSubject, + }, + { + name: "key variation: 'empty'", + suSe: SuSe{}, + want: logoutModeUndefined, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mode := getLogoutMode(tt.suSe) + require.Equal(t, tt.want, mode) + }) + } +} + +func TestGetLogoutRecords(t *testing.T) { + sessionStore := store.NewMemoryStore() + + recordClaimA := &store.Record{Key: "claim-a", Value: []byte("claim-a-data")} + recordClaimB := &store.Record{Key: "claim-b", Value: []byte("claim-b-data")} + recordClaimC := &store.Record{Key: "claim-c", Value: []byte("claim-c-data")} + recordClaimD := &store.Record{Key: "claim-d", Value: []byte("claim-d-data")} + recordSessionA := &store.Record{Key: mustNewKey(t, "", "session-a"), Value: []byte(recordClaimA.Key)} + recordSessionB := &store.Record{Key: mustNewKey(t, "", "session-b"), Value: []byte(recordClaimB.Key)} + recordSubjectASessionC := &store.Record{Key: mustNewKey(t, "subject-a", "session-c"), Value: []byte(recordSessionA.Key)} + recordSubjectASessionD := &store.Record{Key: mustNewKey(t, "subject-a", "session-d"), Value: []byte(recordSessionA.Key)} + + for _, r := range []*store.Record{ + recordClaimA, + recordClaimB, + recordClaimC, + recordClaimD, + recordSessionA, + recordSessionB, + recordSubjectASessionC, + recordSubjectASessionD, + } { + require.NoError(t, sessionStore.Write(r)) + } + + tests := []struct { + name string + suSe SuSe + store func(t *testing.T) store.Store + wantRecords []*store.Record + wantErrs []error + }{ + { + name: "fails if multiple session records are found", + suSe: mustNewSuSe(t, "", "session-a"), + store: func(t *testing.T) store.Store { + s := mocks.NewStore(t) + s.EXPECT().Read(mock.Anything, mock.Anything).Return([]*store.Record{ + recordSessionA, + recordSessionB, + }, nil) + return s + }, + wantRecords: []*store.Record{}, + wantErrs: []error{ErrSuspiciousCacheResult}}, + { + name: "fails if the record key is not ok", + suSe: mustNewSuSe(t, "", "session-a"), + store: func(t *testing.T) store.Store { + s := mocks.NewStore(t) + s.EXPECT().Read(mock.Anything, mock.Anything).Return([]*store.Record{ + {Key: "invalid.record.key"}, + }, nil) + return s + }, + wantRecords: []*store.Record{}, + wantErrs: []error{ErrInvalidSubjectOrSession, ErrSuspiciousCacheResult}, + }, + { + name: "fails if the session does not match the retrieved record", + suSe: mustNewSuSe(t, "", "session-a"), + store: func(t *testing.T) store.Store { + s := mocks.NewStore(t) + s.EXPECT().Read(mock.Anything, mock.Anything).Return([]*store.Record{ + recordSessionB, + }, nil) + return s + }, + wantRecords: []*store.Record{}, + wantErrs: []error{ErrSuspiciousCacheResult}}, + { + name: "fails if the subject does not match the retrieved record", + suSe: mustNewSuSe(t, "subject-a", ""), + store: func(t *testing.T) store.Store { + s := mocks.NewStore(t) + s.EXPECT().Read(mock.Anything, mock.Anything).Return([]*store.Record{ + recordSessionB, + }, nil) + return s + }, + wantRecords: []*store.Record{}, + wantErrs: []error{ErrSuspiciousCacheResult}}, + // key variation tests + { + name: "key variation: 'session-a'", + suSe: mustNewSuSe(t, "", "session-a"), + store: func(*testing.T) store.Store { + return sessionStore + }, + wantRecords: []*store.Record{recordSessionA}, + }, + { + name: "key variation: 'session-b'", + suSe: mustNewSuSe(t, "", "session-b"), + store: func(*testing.T) store.Store { + return sessionStore + }, + wantRecords: []*store.Record{recordSessionB}, + }, + { + name: "key variation: 'session-c'", + suSe: mustNewSuSe(t, "", "session-c"), + store: func(*testing.T) store.Store { + return sessionStore + }, + wantRecords: []*store.Record{recordSubjectASessionC}, + }, + { + name: "key variation: 'ession-c'", + suSe: mustNewSuSe(t, "", "ession-c"), + store: func(*testing.T) store.Store { + return sessionStore + }, + wantRecords: []*store.Record{}, + wantErrs: []error{store.ErrNotFound}, + }, + { + name: "key variation: 'subject-a'", + suSe: mustNewSuSe(t, "subject-a", ""), + store: func(*testing.T) store.Store { + return sessionStore + }, + wantRecords: []*store.Record{recordSubjectASessionC, recordSubjectASessionD}, + }, + { + name: "key variation: 'subject-'", + suSe: mustNewSuSe(t, "subject-", ""), + store: func(*testing.T) store.Store { + return sessionStore + }, + wantRecords: []*store.Record{}, + wantErrs: []error{store.ErrNotFound}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + records, err := GetLogoutRecords(tt.suSe, tt.store(t)) + for _, wantErr := range tt.wantErrs { + require.ErrorIs(t, err, wantErr) + } + require.Len(t, records, len(tt.wantRecords)) + + sortRecords := func(r []*store.Record) []*store.Record { + slices.SortFunc(r, func(a, b *store.Record) int { + return strings.Compare(a.Key, b.Key) + }) + + return r + } + + records = sortRecords(records) + for i, wantRecords := range sortRecords(tt.wantRecords) { + require.True(t, len(records) >= i+1) + require.Equal(t, wantRecords.Key, records[i].Key) + require.Equal(t, wantRecords.Value, records[i].Value) + } + }) + } +} diff --git a/services/proxy/pkg/staticroutes/internal/backchannellogout/mocks/store.go b/services/proxy/pkg/staticroutes/internal/backchannellogout/mocks/store.go new file mode 100644 index 0000000000..359ea9cc2b --- /dev/null +++ b/services/proxy/pkg/staticroutes/internal/backchannellogout/mocks/store.go @@ -0,0 +1,509 @@ +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + mock "github.com/stretchr/testify/mock" + "go-micro.dev/v4/store" +) + +// NewStore creates a new instance of Store. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewStore(t interface { + mock.TestingT + Cleanup(func()) +}) *Store { + mock := &Store{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// Store is an autogenerated mock type for the Store type +type Store struct { + mock.Mock +} + +type Store_Expecter struct { + mock *mock.Mock +} + +func (_m *Store) EXPECT() *Store_Expecter { + return &Store_Expecter{mock: &_m.Mock} +} + +// Close provides a mock function for the type Store +func (_mock *Store) Close() error { + ret := _mock.Called() + + if len(ret) == 0 { + panic("no return value specified for Close") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func() error); ok { + r0 = returnFunc() + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Store_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' +type Store_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +func (_e *Store_Expecter) Close() *Store_Close_Call { + return &Store_Close_Call{Call: _e.mock.On("Close")} +} + +func (_c *Store_Close_Call) Run(run func()) *Store_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Store_Close_Call) Return(err error) *Store_Close_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Store_Close_Call) RunAndReturn(run func() error) *Store_Close_Call { + _c.Call.Return(run) + return _c +} + +// Delete provides a mock function for the type Store +func (_mock *Store) Delete(key string, opts ...store.DeleteOption) error { + var tmpRet mock.Arguments + if len(opts) > 0 { + tmpRet = _mock.Called(key, opts) + } else { + tmpRet = _mock.Called(key) + } + ret := tmpRet + + if len(ret) == 0 { + panic("no return value specified for Delete") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(string, ...store.DeleteOption) error); ok { + r0 = returnFunc(key, opts...) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Store_Delete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Delete' +type Store_Delete_Call struct { + *mock.Call +} + +// Delete is a helper method to define mock.On call +// - key string +// - opts ...store.DeleteOption +func (_e *Store_Expecter) Delete(key interface{}, opts ...interface{}) *Store_Delete_Call { + return &Store_Delete_Call{Call: _e.mock.On("Delete", + append([]interface{}{key}, opts...)...)} +} + +func (_c *Store_Delete_Call) Run(run func(key string, opts ...store.DeleteOption)) *Store_Delete_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + var arg1 []store.DeleteOption + var variadicArgs []store.DeleteOption + if len(args) > 1 { + variadicArgs = args[1].([]store.DeleteOption) + } + arg1 = variadicArgs + run( + arg0, + arg1..., + ) + }) + return _c +} + +func (_c *Store_Delete_Call) Return(err error) *Store_Delete_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Store_Delete_Call) RunAndReturn(run func(key string, opts ...store.DeleteOption) error) *Store_Delete_Call { + _c.Call.Return(run) + return _c +} + +// Init provides a mock function for the type Store +func (_mock *Store) Init(options ...store.Option) error { + var tmpRet mock.Arguments + if len(options) > 0 { + tmpRet = _mock.Called(options) + } else { + tmpRet = _mock.Called() + } + ret := tmpRet + + if len(ret) == 0 { + panic("no return value specified for Init") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(...store.Option) error); ok { + r0 = returnFunc(options...) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Store_Init_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Init' +type Store_Init_Call struct { + *mock.Call +} + +// Init is a helper method to define mock.On call +// - options ...store.Option +func (_e *Store_Expecter) Init(options ...interface{}) *Store_Init_Call { + return &Store_Init_Call{Call: _e.mock.On("Init", + append([]interface{}{}, options...)...)} +} + +func (_c *Store_Init_Call) Run(run func(options ...store.Option)) *Store_Init_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 []store.Option + var variadicArgs []store.Option + if len(args) > 0 { + variadicArgs = args[0].([]store.Option) + } + arg0 = variadicArgs + run( + arg0..., + ) + }) + return _c +} + +func (_c *Store_Init_Call) Return(err error) *Store_Init_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Store_Init_Call) RunAndReturn(run func(options ...store.Option) error) *Store_Init_Call { + _c.Call.Return(run) + return _c +} + +// List provides a mock function for the type Store +func (_mock *Store) List(opts ...store.ListOption) ([]string, error) { + var tmpRet mock.Arguments + if len(opts) > 0 { + tmpRet = _mock.Called(opts) + } else { + tmpRet = _mock.Called() + } + ret := tmpRet + + if len(ret) == 0 { + panic("no return value specified for List") + } + + var r0 []string + var r1 error + if returnFunc, ok := ret.Get(0).(func(...store.ListOption) ([]string, error)); ok { + return returnFunc(opts...) + } + if returnFunc, ok := ret.Get(0).(func(...store.ListOption) []string); ok { + r0 = returnFunc(opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + if returnFunc, ok := ret.Get(1).(func(...store.ListOption) error); ok { + r1 = returnFunc(opts...) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Store_List_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'List' +type Store_List_Call struct { + *mock.Call +} + +// List is a helper method to define mock.On call +// - opts ...store.ListOption +func (_e *Store_Expecter) List(opts ...interface{}) *Store_List_Call { + return &Store_List_Call{Call: _e.mock.On("List", + append([]interface{}{}, opts...)...)} +} + +func (_c *Store_List_Call) Run(run func(opts ...store.ListOption)) *Store_List_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 []store.ListOption + var variadicArgs []store.ListOption + if len(args) > 0 { + variadicArgs = args[0].([]store.ListOption) + } + arg0 = variadicArgs + run( + arg0..., + ) + }) + return _c +} + +func (_c *Store_List_Call) Return(strings []string, err error) *Store_List_Call { + _c.Call.Return(strings, err) + return _c +} + +func (_c *Store_List_Call) RunAndReturn(run func(opts ...store.ListOption) ([]string, error)) *Store_List_Call { + _c.Call.Return(run) + return _c +} + +// Options provides a mock function for the type Store +func (_mock *Store) Options() store.Options { + ret := _mock.Called() + + if len(ret) == 0 { + panic("no return value specified for Options") + } + + var r0 store.Options + if returnFunc, ok := ret.Get(0).(func() store.Options); ok { + r0 = returnFunc() + } else { + r0 = ret.Get(0).(store.Options) + } + return r0 +} + +// Store_Options_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Options' +type Store_Options_Call struct { + *mock.Call +} + +// Options is a helper method to define mock.On call +func (_e *Store_Expecter) Options() *Store_Options_Call { + return &Store_Options_Call{Call: _e.mock.On("Options")} +} + +func (_c *Store_Options_Call) Run(run func()) *Store_Options_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Store_Options_Call) Return(options store.Options) *Store_Options_Call { + _c.Call.Return(options) + return _c +} + +func (_c *Store_Options_Call) RunAndReturn(run func() store.Options) *Store_Options_Call { + _c.Call.Return(run) + return _c +} + +// Read provides a mock function for the type Store +func (_mock *Store) Read(key string, opts ...store.ReadOption) ([]*store.Record, error) { + var tmpRet mock.Arguments + if len(opts) > 0 { + tmpRet = _mock.Called(key, opts) + } else { + tmpRet = _mock.Called(key) + } + ret := tmpRet + + if len(ret) == 0 { + panic("no return value specified for Read") + } + + var r0 []*store.Record + var r1 error + if returnFunc, ok := ret.Get(0).(func(string, ...store.ReadOption) ([]*store.Record, error)); ok { + return returnFunc(key, opts...) + } + if returnFunc, ok := ret.Get(0).(func(string, ...store.ReadOption) []*store.Record); ok { + r0 = returnFunc(key, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*store.Record) + } + } + if returnFunc, ok := ret.Get(1).(func(string, ...store.ReadOption) error); ok { + r1 = returnFunc(key, opts...) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Store_Read_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Read' +type Store_Read_Call struct { + *mock.Call +} + +// Read is a helper method to define mock.On call +// - key string +// - opts ...store.ReadOption +func (_e *Store_Expecter) Read(key interface{}, opts ...interface{}) *Store_Read_Call { + return &Store_Read_Call{Call: _e.mock.On("Read", + append([]interface{}{key}, opts...)...)} +} + +func (_c *Store_Read_Call) Run(run func(key string, opts ...store.ReadOption)) *Store_Read_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + var arg1 []store.ReadOption + var variadicArgs []store.ReadOption + if len(args) > 1 { + variadicArgs = args[1].([]store.ReadOption) + } + arg1 = variadicArgs + run( + arg0, + arg1..., + ) + }) + return _c +} + +func (_c *Store_Read_Call) Return(records []*store.Record, err error) *Store_Read_Call { + _c.Call.Return(records, err) + return _c +} + +func (_c *Store_Read_Call) RunAndReturn(run func(key string, opts ...store.ReadOption) ([]*store.Record, error)) *Store_Read_Call { + _c.Call.Return(run) + return _c +} + +// String provides a mock function for the type Store +func (_mock *Store) String() string { + ret := _mock.Called() + + if len(ret) == 0 { + panic("no return value specified for String") + } + + var r0 string + if returnFunc, ok := ret.Get(0).(func() string); ok { + r0 = returnFunc() + } else { + r0 = ret.Get(0).(string) + } + return r0 +} + +// Store_String_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'String' +type Store_String_Call struct { + *mock.Call +} + +// String is a helper method to define mock.On call +func (_e *Store_Expecter) String() *Store_String_Call { + return &Store_String_Call{Call: _e.mock.On("String")} +} + +func (_c *Store_String_Call) Run(run func()) *Store_String_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Store_String_Call) Return(s string) *Store_String_Call { + _c.Call.Return(s) + return _c +} + +func (_c *Store_String_Call) RunAndReturn(run func() string) *Store_String_Call { + _c.Call.Return(run) + return _c +} + +// Write provides a mock function for the type Store +func (_mock *Store) Write(r *store.Record, opts ...store.WriteOption) error { + var tmpRet mock.Arguments + if len(opts) > 0 { + tmpRet = _mock.Called(r, opts) + } else { + tmpRet = _mock.Called(r) + } + ret := tmpRet + + if len(ret) == 0 { + panic("no return value specified for Write") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(*store.Record, ...store.WriteOption) error); ok { + r0 = returnFunc(r, opts...) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Store_Write_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Write' +type Store_Write_Call struct { + *mock.Call +} + +// Write is a helper method to define mock.On call +// - r *store.Record +// - opts ...store.WriteOption +func (_e *Store_Expecter) Write(r interface{}, opts ...interface{}) *Store_Write_Call { + return &Store_Write_Call{Call: _e.mock.On("Write", + append([]interface{}{r}, opts...)...)} +} + +func (_c *Store_Write_Call) Run(run func(r *store.Record, opts ...store.WriteOption)) *Store_Write_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 *store.Record + if args[0] != nil { + arg0 = args[0].(*store.Record) + } + var arg1 []store.WriteOption + var variadicArgs []store.WriteOption + if len(args) > 1 { + variadicArgs = args[1].([]store.WriteOption) + } + arg1 = variadicArgs + run( + arg0, + arg1..., + ) + }) + return _c +} + +func (_c *Store_Write_Call) Return(err error) *Store_Write_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Store_Write_Call) RunAndReturn(run func(r *store.Record, opts ...store.WriteOption) error) *Store_Write_Call { + _c.Call.Return(run) + return _c +}