From 7223f7cc4d669f9154fd87bee3ef21ba0e584a40 Mon Sep 17 00:00:00 2001 From: Shivansh Vij Date: Thu, 12 Jan 2023 16:24:17 -0500 Subject: [PATCH] Add Foreign Key Constraints to Sessions (#4) * Updating required schema * Updating schema for storage to be more streamlined * Fixing minor bugs and renaming functions/interfaces * Making manager more streamlined by only exposing required functions * Fixing and streamlining bugs around flow verification * Updating generated code * Updating generated code --- auth.go | 30 +- pkg/api/v1/docs/api_docs.go | 34 +- pkg/api/v1/docs/api_swagger.json | 34 +- pkg/api/v1/docs/api_swagger.yaml | 23 +- pkg/api/v1/github/github.go | 21 +- pkg/api/v1/models/models.go | 14 +- pkg/api/v1/servicekey/servicekey.go | 23 +- pkg/api/v1/v1.go | 2 +- pkg/database/database.go | 74 +-- pkg/database/device.go | 77 ++++ pkg/database/github.go | 49 ++ pkg/manager/manager.go | 430 +++++++++--------- pkg/provider/device/database.go | 1 + pkg/provider/device/device.go | 13 + pkg/provider/github/github.go | 2 +- .../servicesession.go} | 23 +- pkg/storage/apikey.go | 54 +++ pkg/storage/registration.go | 42 ++ pkg/storage/secretkey.go | 54 +++ pkg/storage/servicekey.go | 34 ++ pkg/storage/servicesession.go | 66 +++ pkg/storage/session.go | 73 +++ pkg/storage/storage.go | 180 +------- pkg/storage/user.go | 46 ++ 24 files changed, 882 insertions(+), 517 deletions(-) create mode 100644 pkg/database/device.go create mode 100644 pkg/database/github.go rename pkg/{servicekey/session.go => servicesession/servicesession.go} (71%) create mode 100644 pkg/storage/apikey.go create mode 100644 pkg/storage/registration.go create mode 100644 pkg/storage/secretkey.go create mode 100644 pkg/storage/servicekey.go create mode 100644 pkg/storage/servicesession.go create mode 100644 pkg/storage/session.go create mode 100644 pkg/storage/user.go diff --git a/auth.go b/auth.go index 84d1384..432c3ca 100644 --- a/auth.go +++ b/auth.go @@ -17,30 +17,30 @@ package auth const ( - APIKeyPrefixString = "AK-" - ServiceKeyPrefixString = "SK-" - ServiceKeySessionPrefixString = "SS-" + APIKeyPrefixString = "AK-" + ServiceKeyPrefixString = "SK-" + ServiceSessionPrefixString = "SS-" ) var ( - APIKeyPrefix = []byte(APIKeyPrefixString) - ServiceKeySessionPrefix = []byte(ServiceKeySessionPrefixString) + APIKeyPrefix = []byte(APIKeyPrefixString) + ServiceKeyPrefix = []byte(ServiceKeyPrefixString) + ServiceSessionPrefix = []byte(ServiceSessionPrefixString) ) const ( - SessionContextKey = "session" - APIKeyContextKey = "apikey" - ServiceKeySessionContextKey = "service" - UserContextKey = "user" - OrganizationContextKey = "organization" + SessionContextKey = "session" + APIKeyContextKey = "apikey" + ServiceSessionContextKey = "service" + UserContextKey = "user" + OrganizationContextKey = "organization" + KindContextKey = "kind" ) type Kind string const ( - KindContextKey Kind = "kind" - - KindSession Kind = "session" - KindAPIKey Kind = "api" - KindServiceKey Kind = "service" + KindSession Kind = "session" + KindAPIKey Kind = "api" + KindServiceSession Kind = "service" ) diff --git a/pkg/api/v1/docs/api_docs.go b/pkg/api/v1/docs/api_docs.go index 56eb06b..51eef48 100644 --- a/pkg/api/v1/docs/api_docs.go +++ b/pkg/api/v1/docs/api_docs.go @@ -302,7 +302,7 @@ const docTemplateapi = `{ }, { "type": "string", - "description": "Device Code Identifier", + "description": "Device Flow Identifier", "name": "identifier", "in": "query" } @@ -440,6 +440,12 @@ const docTemplateapi = `{ "responses": { "200": { "description": "OK", + "schema": { + "$ref": "#/definitions/models.ServiceKeyLoginResponse" + } + }, + "400": { + "description": "Bad Request", "schema": { "type": "string" } @@ -490,6 +496,32 @@ const docTemplateapi = `{ "type": "string" } } + }, + "models.ServiceKeyLoginResponse": { + "type": "object", + "properties": { + "organization": { + "type": "string" + }, + "resource_id": { + "type": "string" + }, + "resource_type": { + "type": "string" + }, + "service_key_id": { + "type": "string" + }, + "service_session_id": { + "type": "string" + }, + "service_session_secret": { + "type": "string" + }, + "user_id": { + "type": "string" + } + } } } }` diff --git a/pkg/api/v1/docs/api_swagger.json b/pkg/api/v1/docs/api_swagger.json index fb651c8..f00885a 100644 --- a/pkg/api/v1/docs/api_swagger.json +++ b/pkg/api/v1/docs/api_swagger.json @@ -282,7 +282,7 @@ }, { "type": "string", - "description": "Device Code Identifier", + "description": "Device Flow Identifier", "name": "identifier", "in": "query" } @@ -420,6 +420,12 @@ "responses": { "200": { "description": "OK", + "schema": { + "$ref": "#/definitions/models.ServiceKeyLoginResponse" + } + }, + "400": { + "description": "Bad Request", "schema": { "type": "string" } @@ -470,6 +476,32 @@ "type": "string" } } + }, + "models.ServiceKeyLoginResponse": { + "type": "object", + "properties": { + "organization": { + "type": "string" + }, + "resource_id": { + "type": "string" + }, + "resource_type": { + "type": "string" + }, + "service_key_id": { + "type": "string" + }, + "service_session_id": { + "type": "string" + }, + "service_session_secret": { + "type": "string" + }, + "user_id": { + "type": "string" + } + } } } } \ No newline at end of file diff --git a/pkg/api/v1/docs/api_swagger.yaml b/pkg/api/v1/docs/api_swagger.yaml index 1f8a1ce..564fd5d 100644 --- a/pkg/api/v1/docs/api_swagger.yaml +++ b/pkg/api/v1/docs/api_swagger.yaml @@ -19,6 +19,23 @@ definitions: user_code: type: string type: object + models.ServiceKeyLoginResponse: + properties: + organization: + type: string + resource_id: + type: string + resource_type: + type: string + service_key_id: + type: string + service_session_id: + type: string + service_session_secret: + type: string + user_id: + type: string + type: object host: localhost:8080 info: contact: @@ -201,7 +218,7 @@ paths: in: query name: organization type: string - - description: Device Code Identifier + - description: Device Flow Identifier in: query name: identifier type: string @@ -296,6 +313,10 @@ paths: responses: "200": description: OK + schema: + $ref: '#/definitions/models.ServiceKeyLoginResponse' + "400": + description: Bad Request schema: type: string "401": diff --git a/pkg/api/v1/github/github.go b/pkg/api/v1/github/github.go index d83ab47..2f6e700 100644 --- a/pkg/api/v1/github/github.go +++ b/pkg/api/v1/github/github.go @@ -62,7 +62,7 @@ func (a *Github) App() *fiber.App { // @Produce json // @Param next query string false "Next Redirect URL" // @Param organization query string false "Organization" -// @Param identifier query string false "Device Code Identifier" +// @Param identifier query string false "Device Flow Identifier" // @Success 307 // @Header 307 {string} Location "Redirects to Github" // @Failure 401 {string} string @@ -74,7 +74,24 @@ func (a *Github) GithubLogin(ctx *fiber.Ctx) error { return ctx.Status(fiber.StatusUnauthorized).SendString("github provider is not enabled") } - redirect, err := a.options.Github().StartFlow(ctx.Context(), ctx.Query("next", a.options.NextURL()), ctx.Query("organization"), ctx.Query("identifier")) + identifier := ctx.Query("identifier") + if identifier != "" { + if a.options.Device() == nil { + return ctx.Status(fiber.StatusUnauthorized).SendString("device provider is not enabled") + } + + exists, err := a.options.Device().FlowExists(ctx.Context(), identifier) + if err != nil { + a.logger.Error().Err(err).Msg("failed to check if flow exists") + return ctx.Status(fiber.StatusInternalServerError).SendString("failed to check if flow exists") + } + + if !exists { + return ctx.Status(fiber.StatusUnauthorized).SendString("invalid device flow identifier") + } + } + + redirect, err := a.options.Github().StartFlow(ctx.Context(), ctx.Query("next", a.options.NextURL()), ctx.Query("organization"), identifier) if err != nil { a.logger.Error().Err(err).Msg("failed to get redirect") return ctx.Status(fiber.StatusInternalServerError).SendString("failed to get redirect") diff --git a/pkg/api/v1/models/models.go b/pkg/api/v1/models/models.go index c1999e0..a1b667f 100644 --- a/pkg/api/v1/models/models.go +++ b/pkg/api/v1/models/models.go @@ -31,11 +31,11 @@ type DeviceCallbackResponse struct { } type ServiceKeyLoginResponse struct { - ServiceKeySessionID string `json:"service_key_session_id"` - ServiceKeySessionSecret string `json:"service_key_session_secret"` - ServiceKeyID string `json:"service_key_id"` - UserID string `json:"user_id"` - Organization string `json:"organization"` - ResourceType string `json:"resource_type"` - ResourceID string `json:"resource_id"` + ServiceSessionID string `json:"service_session_id"` + ServiceSessionSecret string `json:"service_session_secret"` + ServiceKeyID string `json:"service_key_id"` + UserID string `json:"user_id"` + Organization string `json:"organization"` + ResourceType string `json:"resource_type"` + ResourceID string `json:"resource_id"` } diff --git a/pkg/api/v1/servicekey/servicekey.go b/pkg/api/v1/servicekey/servicekey.go index 6f64e8f..4b5d13f 100644 --- a/pkg/api/v1/servicekey/servicekey.go +++ b/pkg/api/v1/servicekey/servicekey.go @@ -62,7 +62,8 @@ func (a *ServiceKey) App() *fiber.App { // @Accept json // @Produce json // @Param servicekey query string true "Service Key" -// @Success 200 {string} string +// @Success 200 {object} models.ServiceKeyLoginResponse +// @Failure 400 {string} string // @Failure 401 {string} string // @Failure 500 {string} string // @Router /servicekey/login [post] @@ -87,22 +88,18 @@ func (a *ServiceKey) ServiceKeyLogin(ctx *fiber.Ctx) error { keySecret := []byte(keySplit[1]) a.logger.Debug().Msgf("logging in user with service key ID %s", keyID) - sess, secret, err := a.options.Manager().CreateServiceKeySession(ctx, keyID, keySecret) + sess, secret, err := a.options.Manager().CreateServiceSession(ctx, keyID, keySecret) if sess == nil || secret == nil { return err } return ctx.JSON(&models.ServiceKeyLoginResponse{ - ServiceKeySessionID: sess.ID, - ServiceKeySessionSecret: string(secret), - ServiceKeyID: sess.ServiceKeyID, - UserID: sess.UserID, - Organization: sess.Organization, - ResourceType: sess.ResourceType, - ResourceID: sess.ResourceID, + ServiceSessionID: sess.ID, + ServiceSessionSecret: string(secret), + ServiceKeyID: sess.ServiceKeyID, + UserID: sess.UserID, + Organization: sess.Organization, + ResourceType: sess.ResourceType, + ResourceID: sess.ResourceID, }) } - -func (a *ServiceKey) ServiceKeyLogout(ctx *fiber.Ctx) error { - return nil -} diff --git a/pkg/api/v1/v1.go b/pkg/api/v1/v1.go index f763ca6..faea42d 100644 --- a/pkg/api/v1/v1.go +++ b/pkg/api/v1/v1.go @@ -101,7 +101,7 @@ func (v *V1) Logout(ctx *fiber.Ctx) error { return err } - err = v.options.Manager().LogoutServiceKeySession(ctx) + err = v.options.Manager().LogoutServiceSession(ctx) if err != nil { return err } diff --git a/pkg/database/database.go b/pkg/database/database.go index fe63109..f0b8858 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -18,19 +18,12 @@ package database import ( "context" - "github.com/loopholelabs/auth/internal/ent" - "github.com/loopholelabs/auth/internal/ent/deviceflow" - "github.com/loopholelabs/auth/internal/ent/githubflow" - "github.com/loopholelabs/auth/pkg/provider/github" - "github.com/rs/zerolog" - "time" - _ "github.com/lib/pq" + "github.com/loopholelabs/auth/internal/ent" _ "github.com/mattn/go-sqlite3" + "github.com/rs/zerolog" ) -var _ github.Database = (*Database)(nil) - type Database struct { logger *zerolog.Logger client *ent.Client @@ -76,66 +69,3 @@ func (d *Database) Shutdown() error { } return nil } - -func (d *Database) SetGithubFlow(ctx context.Context, state string, verifier string, challenge string, nextURL string, organization string, deviceIdentifier string) error { - d.logger.Debug().Msgf("setting github flow for %s", state) - _, err := d.client.GithubFlow.Create().SetState(state).SetVerifier(verifier).SetChallenge(challenge).SetNextURL(nextURL).SetOrganization(organization).SetDeviceIdentifier(deviceIdentifier).Save(ctx) - return err -} - -func (d *Database) GetGithubFlow(ctx context.Context, state string) (*ent.GithubFlow, error) { - d.logger.Debug().Msgf("getting github flow for %s", state) - return d.client.GithubFlow.Query().Where(githubflow.State(state)).Only(ctx) -} - -func (d *Database) DeleteGithubFlow(ctx context.Context, state string) error { - d.logger.Debug().Msgf("deleting github flow for %s", state) - _, err := d.client.GithubFlow.Delete().Where(githubflow.State(state)).Exec(ctx) - return err -} - -func (d *Database) GCGithubFlow(ctx context.Context, expiry time.Duration) (int, error) { - d.logger.Debug().Msgf("running github flow gc") - return d.client.GithubFlow.Delete().Where(githubflow.CreatedAtLT(time.Now().Add(expiry))).Exec(ctx) -} - -func (d *Database) SetDeviceFlow(ctx context.Context, identifier string, deviceCode string, userCode string) error { - d.logger.Debug().Msgf("setting device flow for %s (device code %s, user code %s)", identifier, deviceCode, userCode) - _, err := d.client.DeviceFlow.Create().SetIdentifier(identifier).SetDeviceCode(deviceCode).SetUserCode(userCode).Save(ctx) - return err -} - -func (d *Database) GetDeviceFlow(ctx context.Context, deviceCode string) (*ent.DeviceFlow, error) { - d.logger.Debug().Msgf("getting device flow for device code %s", deviceCode) - return d.client.DeviceFlow.Query().Where(deviceflow.DeviceCode(deviceCode)).Only(ctx) -} - -func (d *Database) UpdateDeviceFlow(ctx context.Context, identifier string, session string, expiry time.Time) error { - d.logger.Debug().Msgf("updating device flow for %s (expiry %s)", identifier, expiry) - _, err := d.client.DeviceFlow.Update().Where(deviceflow.Identifier(identifier)).SetSession(session).SetExpiresAt(expiry).Save(ctx) - return err -} - -func (d *Database) GetDeviceFlowUserCode(ctx context.Context, userCode string) (*ent.DeviceFlow, error) { - d.logger.Debug().Msgf("getting device flow for user code %s", userCode) - flow, err := d.client.DeviceFlow.Query().Where(deviceflow.UserCode(userCode)).Only(ctx) - if err != nil { - return nil, err - } - _, err = flow.Update().SetLastPoll(time.Now()).Save(ctx) - if err != nil { - return nil, err - } - return flow, nil -} - -func (d *Database) DeleteDeviceFlow(ctx context.Context, deviceCode string) error { - d.logger.Debug().Msgf("deleting device flow for device code %s", deviceCode) - _, err := d.client.DeviceFlow.Delete().Where(deviceflow.DeviceCode(deviceCode)).Exec(ctx) - return err -} - -func (d *Database) GCDeviceFlow(ctx context.Context, expiry time.Duration) (int, error) { - d.logger.Debug().Msgf("running device flow gc") - return d.client.DeviceFlow.Delete().Where(deviceflow.CreatedAtLT(time.Now().Add(expiry))).Exec(ctx) -} diff --git a/pkg/database/device.go b/pkg/database/device.go new file mode 100644 index 0000000..f3a6dee --- /dev/null +++ b/pkg/database/device.go @@ -0,0 +1,77 @@ +/* + Copyright 2023 Loophole Labs + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package database + +import ( + "context" + "github.com/loopholelabs/auth/internal/ent" + "github.com/loopholelabs/auth/internal/ent/deviceflow" + "github.com/loopholelabs/auth/pkg/provider/device" + "time" +) + +var _ device.Database = (*Database)(nil) + +func (d *Database) SetDeviceFlow(ctx context.Context, identifier string, deviceCode string, userCode string) error { + d.logger.Debug().Msgf("setting device flow for %s (device code %s, user code %s)", identifier, deviceCode, userCode) + _, err := d.client.DeviceFlow.Create().SetIdentifier(identifier).SetDeviceCode(deviceCode).SetUserCode(userCode).Save(ctx) + return err +} + +func (d *Database) GetDeviceFlow(ctx context.Context, deviceCode string) (*ent.DeviceFlow, error) { + d.logger.Debug().Msgf("getting device flow for device code %s", deviceCode) + return d.client.DeviceFlow.Query().Where(deviceflow.DeviceCode(deviceCode)).Only(ctx) +} + +func (d *Database) UpdateDeviceFlow(ctx context.Context, identifier string, session string, expiry time.Time) error { + d.logger.Debug().Msgf("updating device flow for %s (expiry %s)", identifier, expiry) + _, err := d.client.DeviceFlow.Update().Where(deviceflow.Identifier(identifier)).SetSession(session).SetExpiresAt(expiry).Save(ctx) + return err +} + +func (d *Database) GetDeviceFlowUserCode(ctx context.Context, userCode string) (*ent.DeviceFlow, error) { + d.logger.Debug().Msgf("getting device flow for user code %s", userCode) + flow, err := d.client.DeviceFlow.Query().Where(deviceflow.UserCode(userCode)).Only(ctx) + if err != nil { + return nil, err + } + _, err = flow.Update().SetLastPoll(time.Now()).Save(ctx) + if err != nil { + return nil, err + } + return flow, nil +} + +func (d *Database) GetDeviceFlowIdentifier(ctx context.Context, identifier string) (*ent.DeviceFlow, error) { + d.logger.Debug().Msgf("getting device flow for identifier %s", identifier) + flow, err := d.client.DeviceFlow.Query().Where(deviceflow.Identifier(identifier)).Only(ctx) + if err != nil { + return nil, err + } + return flow, nil +} + +func (d *Database) DeleteDeviceFlow(ctx context.Context, deviceCode string) error { + d.logger.Debug().Msgf("deleting device flow for device code %s", deviceCode) + _, err := d.client.DeviceFlow.Delete().Where(deviceflow.DeviceCode(deviceCode)).Exec(ctx) + return err +} + +func (d *Database) GCDeviceFlow(ctx context.Context, expiry time.Duration) (int, error) { + d.logger.Debug().Msgf("running device flow gc") + return d.client.DeviceFlow.Delete().Where(deviceflow.CreatedAtLT(time.Now().Add(expiry))).Exec(ctx) +} diff --git a/pkg/database/github.go b/pkg/database/github.go new file mode 100644 index 0000000..cd36ed9 --- /dev/null +++ b/pkg/database/github.go @@ -0,0 +1,49 @@ +/* + Copyright 2023 Loophole Labs + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package database + +import ( + "context" + "github.com/loopholelabs/auth/internal/ent" + "github.com/loopholelabs/auth/internal/ent/githubflow" + "github.com/loopholelabs/auth/pkg/provider/github" + "time" +) + +var _ github.Database = (*Database)(nil) + +func (d *Database) SetGithubFlow(ctx context.Context, state string, verifier string, challenge string, nextURL string, organization string, deviceIdentifier string) error { + d.logger.Debug().Msgf("setting github flow for %s", state) + _, err := d.client.GithubFlow.Create().SetState(state).SetVerifier(verifier).SetChallenge(challenge).SetNextURL(nextURL).SetOrganization(organization).SetDeviceIdentifier(deviceIdentifier).Save(ctx) + return err +} + +func (d *Database) GetGithubFlow(ctx context.Context, state string) (*ent.GithubFlow, error) { + d.logger.Debug().Msgf("getting github flow for %s", state) + return d.client.GithubFlow.Query().Where(githubflow.State(state)).Only(ctx) +} + +func (d *Database) DeleteGithubFlow(ctx context.Context, state string) error { + d.logger.Debug().Msgf("deleting github flow for %s", state) + _, err := d.client.GithubFlow.Delete().Where(githubflow.State(state)).Exec(ctx) + return err +} + +func (d *Database) GCGithubFlow(ctx context.Context, expiry time.Duration) (int, error) { + d.logger.Debug().Msgf("running github flow gc") + return d.client.GithubFlow.Delete().Where(githubflow.CreatedAtLT(time.Now().Add(expiry))).Exec(ctx) +} diff --git a/pkg/manager/manager.go b/pkg/manager/manager.go index b6bc50c..4f2e65c 100644 --- a/pkg/manager/manager.go +++ b/pkg/manager/manager.go @@ -29,7 +29,7 @@ import ( "github.com/loopholelabs/auth/pkg/claims" "github.com/loopholelabs/auth/pkg/kind" "github.com/loopholelabs/auth/pkg/provider" - "github.com/loopholelabs/auth/pkg/servicekey" + "github.com/loopholelabs/auth/pkg/servicesession" "github.com/loopholelabs/auth/pkg/session" "github.com/loopholelabs/auth/pkg/storage" "github.com/loopholelabs/auth/pkg/utils" @@ -72,8 +72,8 @@ type Manager struct { sessions map[string]struct{} sessionsMu sync.RWMutex - servicekeySessions map[string]*servicekey.Session - servicekeySessionsMu sync.RWMutex + serviceSessions map[string]*servicesession.ServiceSession + serviceSessionsMu sync.RWMutex apikeys map[string]*apikey.APIKey apikeysMu sync.RWMutex @@ -83,15 +83,15 @@ func New(domain string, tls bool, storage storage.Storage, logger *zerolog.Logge l := logger.With().Str("AUTH", "SESSION-MANAGER").Logger() ctx, cancel := context.WithCancel(context.Background()) return &Manager{ - logger: &l, - storage: storage, - domain: domain, - tls: tls, - ctx: ctx, - cancel: cancel, - sessions: make(map[string]struct{}), - servicekeySessions: make(map[string]*servicekey.Session), - apikeys: make(map[string]*apikey.APIKey), + logger: &l, + storage: storage, + domain: domain, + tls: tls, + ctx: ctx, + cancel: cancel, + sessions: make(map[string]struct{}), + serviceSessions: make(map[string]*servicesession.ServiceSession), + apikeys: make(map[string]*apikey.APIKey), } } @@ -99,14 +99,11 @@ func (m *Manager) Start() error { m.logger.Info().Msg("starting manager") m.secretKeyMu.Lock() - secretKeyEvents, err := m.storage.SubscribeToSecretKey(m.ctx) - if err != nil { - m.secretKeyMu.Unlock() - return fmt.Errorf("failed to subscribe to secret key events: %w", err) - } + secretKeyEvents := m.storage.SubscribeToSecretKey(m.ctx) m.wg.Add(1) go m.subscribeToSecretKeyEvents(secretKeyEvents) m.logger.Info().Msg("subscribed to secret key events") + var err error m.secretKey, err = m.storage.GetSecretKey(m.ctx) if err != nil { if errors.Is(err, storage.ErrNotFound) { @@ -127,7 +124,7 @@ func (m *Manager) Start() error { m.logger.Info().Msg("retrieved secret key") m.registrationMu.Lock() - registrationEvents, err := m.storage.SubscribeToRegistration(m.ctx) + registrationEvents := m.storage.SubscribeToRegistration(m.ctx) if err != nil { m.registrationMu.Unlock() return fmt.Errorf("failed to subscribe to registration events: %w", err) @@ -143,47 +140,47 @@ func (m *Manager) Start() error { m.logger.Info().Msg("retrieved registration") m.sessionsMu.Lock() - sessionEvents, err := m.storage.SubscribeToSessionIDs(m.ctx) + sessionEvents := m.storage.SubscribeToSessions(m.ctx) if err != nil { m.sessionsMu.Unlock() return fmt.Errorf("failed to subscribe to session events: %w", err) } m.wg.Add(1) - go m.subscribeToSessionIDEvents(sessionEvents) - m.logger.Info().Msg("subscribed to session ID events") - sessions, err := m.storage.ListSessionIDs(m.ctx) + go m.subscribeToSessionEvents(sessionEvents) + m.logger.Info().Msg("subscribed to session events") + sessions, err := m.storage.ListSessions(m.ctx) if err != nil { m.sessionsMu.Unlock() return fmt.Errorf("failed to list session IDs: %w", err) } for _, sess := range sessions { - m.sessions[sess] = struct{}{} + m.sessions[sess.ID] = struct{}{} } m.sessionsMu.Unlock() - m.logger.Info().Msg("retrieved session IDs") + m.logger.Info().Msg("retrieved sessions") - m.servicekeySessionsMu.Lock() - servicekeySessionEvents, err := m.storage.SubscribeToServiceKeySessions(m.ctx) + m.serviceSessionsMu.Lock() + serviceSessionEvents := m.storage.SubscribeToServiceSessions(m.ctx) if err != nil { - m.servicekeySessionsMu.Unlock() - return fmt.Errorf("failed to subscribe to service key session events: %w", err) + m.serviceSessionsMu.Unlock() + return fmt.Errorf("failed to subscribe to service session events: %w", err) } m.wg.Add(1) - go m.subscribeToServiceKeySessionEvents(servicekeySessionEvents) - m.logger.Info().Msg("subscribed to service key session events") - servicekeySessionIDs, err := m.storage.ListServiceKeySessions(m.ctx) + go m.subscribeToServiceSessionEvents(serviceSessionEvents) + m.logger.Info().Msg("subscribed to service session events") + serviceSessions, err := m.storage.ListServiceSessions(m.ctx) if err != nil { - m.servicekeySessionsMu.Unlock() - return fmt.Errorf("failed to list service key session IDs: %w", err) + m.serviceSessionsMu.Unlock() + return fmt.Errorf("failed to list service sessions: %w", err) } - for _, sess := range servicekeySessionIDs { - m.servicekeySessions[sess.ID] = sess + for _, sess := range serviceSessions { + m.serviceSessions[sess.ID] = sess } - m.servicekeySessionsMu.Unlock() - m.logger.Info().Msg("retrieved service key sessions") + m.serviceSessionsMu.Unlock() + m.logger.Info().Msg("retrieved service sessions") m.apikeysMu.Lock() - apikeyEvents, err := m.storage.SubscribeToAPIKeys(m.ctx) + apikeyEvents := m.storage.SubscribeToAPIKeys(m.ctx) if err != nil { m.apikeysMu.Unlock() return fmt.Errorf("failed to subscribe to api key events: %w", err) @@ -301,130 +298,18 @@ func (m *Manager) CreateSession(ctx *fiber.Ctx, kind kind.Kind, provider provide return m.GenerateCookie(encrypted, sess.Expiry), nil } -func (m *Manager) GetSession(ctx *fiber.Ctx, cookie string) (*session.Session, error) { - m.secretKeyMu.RLock() - secretKey := m.secretKey - oldSecretKey := m.oldSecretKey - m.secretKeyMu.RUnlock() - - oldSecretKeyUsed := false - decrypted, err := aes.Decrypt(secretKey, CookieKey, cookie) - if err != nil { - if errors.Is(err, aes.ErrInvalidContent) { - if oldSecretKey != nil { - decrypted, err = aes.Decrypt(oldSecretKey, CookieKey, cookie) - if err != nil { - if errors.Is(err, aes.ErrInvalidContent) { - return nil, ctx.Status(fiber.StatusUnauthorized).SendString("invalid session cookie") - } - m.logger.Error().Err(err).Msg("failed to decrypt session with old secret key") - return nil, ctx.Status(fiber.StatusInternalServerError).SendString("failed to decrypt session") - } - oldSecretKeyUsed = true - } else { - return nil, ctx.Status(fiber.StatusUnauthorized).SendString("invalid session cookie") - } - } else { - m.logger.Error().Err(err).Msg("failed to decrypt session") - return nil, ctx.Status(fiber.StatusInternalServerError).SendString("failed to decrypt session") - } - } - - sess := new(session.Session) - err = json.Unmarshal(decrypted, sess) - if err != nil { - m.logger.Error().Err(err).Msg("failed to unmarshal session") - return nil, ctx.Status(fiber.StatusInternalServerError).SendString("failed to unmarshal session") - } - - if sess.Expired() { - return nil, ctx.Status(fiber.StatusUnauthorized).SendString("session expired") - } - - m.sessionsMu.RLock() - _, exists := m.sessions[sess.ID] - m.sessionsMu.RUnlock() - if !exists { - exists, err = m.storage.SessionIDExists(ctx.Context(), sess.ID) - if err != nil { - m.logger.Error().Err(err).Msg("failed to check if session exists") - return nil, ctx.Status(fiber.StatusInternalServerError).SendString("failed to check if session exists") - } - if !exists { - return nil, ctx.Status(fiber.StatusUnauthorized).SendString("session does not exist") - } - } - - if oldSecretKeyUsed || sess.CloseToExpiry() { - sess.Refresh() - data, err := json.Marshal(sess) - if err != nil { - m.logger.Error().Err(err).Msg("failed to marshal refreshed session") - return nil, ctx.Status(fiber.StatusInternalServerError).SendString("failed to marshal session") - } - - encrypted, err := aes.Encrypt(secretKey, CookieKey, data) - if err != nil { - m.logger.Error().Err(err).Msg("failed to encrypt refreshed session") - return nil, ctx.Status(fiber.StatusInternalServerError).SendString("failed to encrypt session") - } - - err = m.storage.SetSession(ctx.Context(), sess) - if err != nil { - m.logger.Error().Err(err).Msg("failed to set refreshed session") - return nil, ctx.Status(fiber.StatusInternalServerError).SendString("failed to set session") - } - - ctx.Cookie(m.GenerateCookie(encrypted, sess.Expiry)) - } - - return sess, nil -} - -func (m *Manager) GetAPIKey(ctx *fiber.Ctx, keyID string, keySecret []byte) (*apikey.APIKey, error) { - m.apikeysMu.RLock() - key, ok := m.apikeys[keyID] - m.apikeysMu.RUnlock() - if !ok { - var err error - key, err = m.storage.GetAPIKey(ctx.Context(), keyID) - if err != nil { - if errors.Is(err, storage.ErrNotFound) { - return nil, ctx.Status(fiber.StatusUnauthorized).SendString("api key does not exist") - } - m.logger.Error().Err(err).Msg("failed to check if api key exists") - return nil, ctx.Status(fiber.StatusInternalServerError).SendString("failed to check if api key exists") - } - } - - if bcrypt.CompareHashAndPassword(keySecret, key.Hash) != nil { - return nil, ctx.Status(fiber.StatusUnauthorized).SendString("invalid api key") - } - - return key, nil -} - -func (m *Manager) GetServiceKey(ctx *fiber.Ctx, keyID string, keySecret []byte) (*servicekey.ServiceKey, error) { - key, err := m.storage.GetServiceKey(ctx.Context(), keyID) +func (m *Manager) CreateServiceSession(ctx *fiber.Ctx, keyID string, keySecret []byte) (*servicesession.ServiceSession, []byte, error) { + serviceKey, err := m.storage.GetServiceKey(ctx.Context(), keyID) if err != nil { if errors.Is(err, storage.ErrNotFound) { - return nil, ctx.Status(fiber.StatusUnauthorized).SendString("service key does not exist") + return nil, nil, ctx.Status(fiber.StatusUnauthorized).SendString("service key does not exist") } m.logger.Error().Err(err).Msg("failed to check if service key exists") - return nil, ctx.Status(fiber.StatusInternalServerError).SendString("failed to check if service key exists") - } - - if bcrypt.CompareHashAndPassword(keySecret, key.Hash) != nil { - return nil, ctx.Status(fiber.StatusUnauthorized).SendString("invalid service key") + return nil, nil, ctx.Status(fiber.StatusInternalServerError).SendString("failed to check if service key exists") } - return key, nil -} - -func (m *Manager) CreateServiceKeySession(ctx *fiber.Ctx, keyID string, keySecret []byte) (*servicekey.Session, []byte, error) { - serviceKey, err := m.GetServiceKey(ctx, keyID, keySecret) - if err != nil { - return nil, nil, err + if bcrypt.CompareHashAndPassword(keySecret, serviceKey.Hash) != nil { + return nil, nil, ctx.Status(fiber.StatusUnauthorized).SendString("invalid service key") } if !serviceKey.Expires.IsZero() && time.Now().After(serviceKey.Expires) { @@ -435,56 +320,33 @@ func (m *Manager) CreateServiceKeySession(ctx *fiber.Ctx, keyID string, keySecre return nil, nil, ctx.Status(fiber.StatusUnauthorized).SendString("service key has reached its maximum uses") } - err = m.storage.IncrementServiceKeyNumUsed(ctx.Context(), serviceKey.ID) + err = m.storage.IncrementServiceKeyNumUsed(ctx.Context(), serviceKey.ID, 1) if err != nil { m.logger.Error().Err(err).Msg("failed to increment service key num used") return nil, nil, ctx.Status(fiber.StatusInternalServerError).SendString("failed to increment service key num used") } - sess, secret, err := servicekey.NewSession(serviceKey) + sess, secret, err := servicesession.New(serviceKey) if err != nil { - m.logger.Error().Err(err).Msg("failed to create service key session") - return nil, nil, ctx.Status(fiber.StatusInternalServerError).SendString("failed to create service key session") + m.logger.Error().Err(err).Msg("failed to create service session") + return nil, nil, ctx.Status(fiber.StatusInternalServerError).SendString("failed to create service session") } - err = m.storage.SetServiceKeySession(ctx.Context(), sess) + err = m.storage.SetServiceSession(ctx.Context(), sess.ID, sess.Hash, sess.ServiceKeyID) if err != nil { - m.logger.Error().Err(err).Msg("failed to set service key session") - return nil, nil, ctx.Status(fiber.StatusInternalServerError).SendString("failed to set service key session") + m.logger.Error().Err(err).Msg("failed to set service session") + return nil, nil, ctx.Status(fiber.StatusInternalServerError).SendString("failed to set service session") } - m.logger.Debug().Msgf("created service key session %s for user %s (org '%s')", sess.ID, sess.UserID, sess.Organization) + m.logger.Debug().Msgf("created service session %s for user %s (org '%s')", sess.ID, sess.UserID, sess.Organization) return sess, secret, nil } -func (m *Manager) GetServiceKeySession(ctx *fiber.Ctx, sessionID string, sessionSecret []byte) (*servicekey.Session, error) { - m.servicekeySessionsMu.RLock() - sess, ok := m.servicekeySessions[sessionID] - m.servicekeySessionsMu.RUnlock() - if !ok { - var err error - sess, err = m.storage.GetServiceKeySession(ctx.Context(), sessionID) - if err != nil { - if errors.Is(err, storage.ErrNotFound) { - return nil, ctx.Status(fiber.StatusUnauthorized).SendString("service key session does not exist") - } - m.logger.Error().Err(err).Msg("failed to check if service key session exists") - return nil, ctx.Status(fiber.StatusInternalServerError).SendString("failed to check if service key session exists") - } - } - - if bcrypt.CompareHashAndPassword(sessionSecret, sess.Hash) != nil { - return nil, ctx.Status(fiber.StatusUnauthorized).SendString("invalid service key session") - } - - return sess, nil -} - func (m *Manager) Validate(ctx *fiber.Ctx) error { cookie := ctx.Cookies(CookieKeyString) if cookie != "" { - sess, err := m.GetSession(ctx, cookie) + sess, err := m.getSession(ctx, cookie) if sess == nil { return err } @@ -511,7 +373,7 @@ func (m *Manager) Validate(ctx *fiber.Ctx) error { keySecret := keySplit[1] if bytes.HasPrefix(authHeader, auth.APIKeyPrefix) { - key, err := m.GetAPIKey(ctx, keyID, keySecret) + key, err := m.getAPIKey(ctx, keyID, keySecret) if key == nil { return err } @@ -523,14 +385,14 @@ func (m *Manager) Validate(ctx *fiber.Ctx) error { return ctx.Next() } - if bytes.HasPrefix(authHeader, auth.ServiceKeySessionPrefix) { - key, err := m.GetServiceKeySession(ctx, keyID, keySecret) + if bytes.HasPrefix(authHeader, auth.ServiceSessionPrefix) { + key, err := m.getServiceSession(ctx, keyID, keySecret) if key == nil { return err } - ctx.Locals(auth.KindContextKey, auth.KindServiceKey) - ctx.Locals(auth.ServiceKeySessionContextKey, key) + ctx.Locals(auth.KindContextKey, auth.KindServiceSession) + ctx.Locals(auth.ServiceSessionContextKey, key) ctx.Locals(auth.UserContextKey, key.UserID) ctx.Locals(auth.OrganizationContextKey, key.Organization) return ctx.Next() @@ -554,7 +416,7 @@ func (m *Manager) LogoutSession(ctx *fiber.Ctx) error { return nil } -func (m *Manager) LogoutServiceKeySession(ctx *fiber.Ctx) error { +func (m *Manager) LogoutServiceSession(ctx *fiber.Ctx) error { authHeader := ctx.Request().Header.PeekBytes(AuthorizationHeader) if len(authHeader) > len(BearerHeader) { if !bytes.Equal(authHeader[:len(BearerHeader)], BearerHeader) { @@ -562,7 +424,7 @@ func (m *Manager) LogoutServiceKeySession(ctx *fiber.Ctx) error { } authHeader = authHeader[len(BearerHeader):] - if !bytes.HasPrefix(authHeader, auth.ServiceKeySessionPrefix) { + if !bytes.HasPrefix(authHeader, auth.ServiceSessionPrefix) { return nil } @@ -574,20 +436,146 @@ func (m *Manager) LogoutServiceKeySession(ctx *fiber.Ctx) error { keyID := string(keySplit[0]) keySecret := keySplit[1] - sess, err := m.GetServiceKeySession(ctx, keyID, keySecret) + sess, err := m.getServiceSession(ctx, keyID, keySecret) if sess == nil { return err } - err = m.storage.DeleteServiceKeySession(ctx.Context(), sess.ID) + err = m.storage.DeleteServiceSession(ctx.Context(), sess.ID) if err != nil { - m.logger.Error().Err(err).Msg("failed to delete service key session") - return ctx.Status(fiber.StatusInternalServerError).SendString("failed to delete service key session") + m.logger.Error().Err(err).Msg("failed to delete service session") + return ctx.Status(fiber.StatusInternalServerError).SendString("failed to delete service session") } } return nil } +func (m *Manager) getSession(ctx *fiber.Ctx, cookie string) (*session.Session, error) { + m.secretKeyMu.RLock() + secretKey := m.secretKey + oldSecretKey := m.oldSecretKey + m.secretKeyMu.RUnlock() + + oldSecretKeyUsed := false + decrypted, err := aes.Decrypt(secretKey, CookieKey, cookie) + if err != nil { + if errors.Is(err, aes.ErrInvalidContent) { + if oldSecretKey != nil { + decrypted, err = aes.Decrypt(oldSecretKey, CookieKey, cookie) + if err != nil { + if errors.Is(err, aes.ErrInvalidContent) { + return nil, ctx.Status(fiber.StatusUnauthorized).SendString("invalid session cookie") + } + m.logger.Error().Err(err).Msg("failed to decrypt session with old secret key") + return nil, ctx.Status(fiber.StatusInternalServerError).SendString("failed to decrypt session") + } + oldSecretKeyUsed = true + } else { + return nil, ctx.Status(fiber.StatusUnauthorized).SendString("invalid session cookie") + } + } else { + m.logger.Error().Err(err).Msg("failed to decrypt session") + return nil, ctx.Status(fiber.StatusInternalServerError).SendString("failed to decrypt session") + } + } + + sess := new(session.Session) + err = json.Unmarshal(decrypted, sess) + if err != nil { + m.logger.Error().Err(err).Msg("failed to unmarshal session") + return nil, ctx.Status(fiber.StatusInternalServerError).SendString("failed to unmarshal session") + } + + if sess.Expired() { + return nil, ctx.Status(fiber.StatusUnauthorized).SendString("session expired") + } + + m.sessionsMu.RLock() + _, exists := m.sessions[sess.ID] + m.sessionsMu.RUnlock() + if !exists { + _, err = m.storage.GetSession(ctx.Context(), sess.ID) + if err != nil { + if errors.Is(err, storage.ErrNotFound) { + return nil, ctx.Status(fiber.StatusUnauthorized).SendString("session does not exist") + } + m.logger.Error().Err(err).Msg("failed to check if session exists") + return nil, ctx.Status(fiber.StatusInternalServerError).SendString("failed to check if session exists") + } + } + + if oldSecretKeyUsed || sess.CloseToExpiry() { + sess.Refresh() + data, err := json.Marshal(sess) + if err != nil { + m.logger.Error().Err(err).Msg("failed to marshal refreshed session") + return nil, ctx.Status(fiber.StatusInternalServerError).SendString("failed to marshal session") + } + + encrypted, err := aes.Encrypt(secretKey, CookieKey, data) + if err != nil { + m.logger.Error().Err(err).Msg("failed to encrypt refreshed session") + return nil, ctx.Status(fiber.StatusInternalServerError).SendString("failed to encrypt session") + } + + err = m.storage.UpdateSessionExpiry(ctx.Context(), sess.ID, sess.Expiry) + if err != nil { + m.logger.Error().Err(err).Msg("failed to update session expiry") + return nil, ctx.Status(fiber.StatusInternalServerError).SendString("failed to update session expiry") + } + + ctx.Cookie(m.GenerateCookie(encrypted, sess.Expiry)) + } + + return sess, nil +} + +func (m *Manager) getAPIKey(ctx *fiber.Ctx, keyID string, keySecret []byte) (*apikey.APIKey, error) { + m.apikeysMu.RLock() + key, ok := m.apikeys[keyID] + m.apikeysMu.RUnlock() + if !ok { + var err error + key, err = m.storage.GetAPIKey(ctx.Context(), keyID) + if err != nil { + if errors.Is(err, storage.ErrNotFound) { + return nil, ctx.Status(fiber.StatusUnauthorized).SendString("api key does not exist") + } + m.logger.Error().Err(err).Msg("failed to check if api key exists") + return nil, ctx.Status(fiber.StatusInternalServerError).SendString("failed to check if api key exists") + } + } + + if bcrypt.CompareHashAndPassword(keySecret, key.Hash) != nil { + return nil, ctx.Status(fiber.StatusUnauthorized).SendString("invalid api key") + } + + return key, nil +} + +func (m *Manager) getServiceSession(ctx *fiber.Ctx, sessionID string, sessionSecret []byte) (*servicesession.ServiceSession, error) { + m.serviceSessionsMu.RLock() + sess, ok := m.serviceSessions[sessionID] + m.serviceSessionsMu.RUnlock() + if !ok { + var err error + sess, err = m.storage.GetServiceSession(ctx.Context(), sessionID) + if err != nil { + if errors.Is(err, storage.ErrNotFound) { + return nil, ctx.Status(fiber.StatusUnauthorized).SendString("service session does not exist") + } + m.logger.Error().Err(err).Msg("failed to check if service session exists") + return nil, ctx.Status(fiber.StatusInternalServerError).SendString("failed to check if service session exists") + } + } + + if bcrypt.CompareHashAndPassword(sessionSecret, sess.Hash) != nil { + return nil, ctx.Status(fiber.StatusUnauthorized).SendString("invalid service session") + } + + return sess, nil +} + func (m *Manager) subscribeToSecretKeyEvents(events <-chan *storage.SecretKeyEvent) { defer m.wg.Done() for { @@ -621,7 +609,7 @@ func (m *Manager) subscribeToRegistrationEvents(events <-chan *storage.Registrat } } -func (m *Manager) subscribeToSessionIDEvents(events <-chan *storage.SessionEvent) { +func (m *Manager) subscribeToSessionEvents(events <-chan *storage.SessionEvent) { defer m.wg.Done() for { select { @@ -630,14 +618,14 @@ func (m *Manager) subscribeToSessionIDEvents(events <-chan *storage.SessionEvent return case event := <-events: if event.Deleted { - m.logger.Debug().Msgf("session %s deleted", event.SessionID) + m.logger.Debug().Msgf("session %s deleted", event.ID) m.sessionsMu.Lock() - delete(m.sessions, event.SessionID) + delete(m.sessions, event.ID) m.sessionsMu.Unlock() } else { - m.logger.Debug().Msgf("session %s created", event.SessionID) + m.logger.Debug().Msgf("session %s created", event.ID) m.sessionsMu.Lock() - m.sessions[event.SessionID] = struct{}{} + m.sessions[event.ID] = struct{}{} m.sessionsMu.Unlock() } } @@ -653,17 +641,17 @@ func (m *Manager) subscribeToAPIKeyEvents(events <-chan *storage.APIKeyEvent) { return case event := <-events: if event.Deleted { - m.logger.Debug().Msgf("api key %s deleted", event.APIKeyID) + m.logger.Debug().Msgf("api key %s deleted", event.ID) m.apikeysMu.Lock() - delete(m.apikeys, event.APIKeyID) + delete(m.apikeys, event.ID) m.apikeysMu.Unlock() } else { - m.logger.Debug().Msgf("api key %s created or updated", event.APIKeyID) + m.logger.Debug().Msgf("api key %s created or updated", event.ID) if event.APIKey == nil { - m.logger.Error().Msgf("api key in create or update event for api key ID %s is nil", event.APIKeyID) + m.logger.Error().Msgf("api key in create or update event for api key ID %s is nil", event.ID) } else { m.apikeysMu.Lock() - m.apikeys[event.APIKeyID] = event.APIKey + m.apikeys[event.ID] = event.APIKey m.apikeysMu.Unlock() } } @@ -671,27 +659,27 @@ func (m *Manager) subscribeToAPIKeyEvents(events <-chan *storage.APIKeyEvent) { } } -func (m *Manager) subscribeToServiceKeySessionEvents(events <-chan *storage.ServiceKeySessionEvent) { +func (m *Manager) subscribeToServiceSessionEvents(events <-chan *storage.ServiceSessionEvent) { defer m.wg.Done() for { select { case <-m.ctx.Done(): - m.logger.Info().Msg("service key session event subscription stopped") + m.logger.Info().Msg("service session event subscription stopped") return case event := <-events: if event.Deleted { - m.logger.Debug().Msgf("service key session %s deleted", event.ServiceKeySessionID) - m.servicekeySessionsMu.Lock() - delete(m.servicekeySessions, event.ServiceKeySessionID) - m.servicekeySessionsMu.Unlock() + m.logger.Debug().Msgf("service session %s deleted", event.ID) + m.serviceSessionsMu.Lock() + delete(m.serviceSessions, event.ID) + m.serviceSessionsMu.Unlock() } else { - m.logger.Debug().Msgf("service key session %s created or updated", event.ServiceKeySessionID) - if event.ServiceKeySession == nil { - m.logger.Error().Msgf("service key session in create or update event for service key session ID %s is nil", event.ServiceKeySessionID) + m.logger.Debug().Msgf("service session %s created or updated", event.ID) + if event.ServiceSession == nil { + m.logger.Error().Msgf("service session in create or update event for service session ID %s is nil", event.ID) } else { - m.servicekeySessionsMu.Lock() - m.servicekeySessions[event.ServiceKeySessionID] = event.ServiceKeySession - m.servicekeySessionsMu.Unlock() + m.serviceSessionsMu.Lock() + m.serviceSessions[event.ID] = event.ServiceSession + m.serviceSessionsMu.Unlock() } } } diff --git a/pkg/provider/device/database.go b/pkg/provider/device/database.go index 69cc972..15b958c 100644 --- a/pkg/provider/device/database.go +++ b/pkg/provider/device/database.go @@ -27,6 +27,7 @@ type Database interface { GetDeviceFlow(ctx context.Context, deviceCode string) (*ent.DeviceFlow, error) UpdateDeviceFlow(ctx context.Context, identifier string, session string, expiry time.Time) error GetDeviceFlowUserCode(ctx context.Context, userCode string) (*ent.DeviceFlow, error) + GetDeviceFlowIdentifier(ctx context.Context, identifier string) (*ent.DeviceFlow, error) DeleteDeviceFlow(ctx context.Context, deviceCode string) error GCDeviceFlow(ctx context.Context, expiry time.Duration) (int, error) } diff --git a/pkg/provider/device/device.go b/pkg/provider/device/device.go index 395d5c7..e61c461 100644 --- a/pkg/provider/device/device.go +++ b/pkg/provider/device/device.go @@ -19,6 +19,7 @@ package device import ( "context" "github.com/google/uuid" + "github.com/loopholelabs/auth/internal/ent" "github.com/loopholelabs/auth/pkg/provider" "github.com/loopholelabs/auth/pkg/utils" "github.com/rs/zerolog" @@ -93,6 +94,18 @@ func (g *Device) ValidateFlow(ctx context.Context, deviceCode string) (string, e return flow.Identifier, nil } +func (g *Device) FlowExists(ctx context.Context, identifier string) (bool, error) { + _, err := g.database.GetDeviceFlowIdentifier(ctx, identifier) + if err != nil { + if ent.IsNotFound(err) { + return false, nil + } + return false, err + } + + return true, nil +} + func (g *Device) PollFlow(ctx context.Context, userCode string) (string, time.Time, time.Time, error) { flow, err := g.database.GetDeviceFlowUserCode(ctx, userCode) if err != nil { diff --git a/pkg/provider/github/github.go b/pkg/provider/github/github.go index 3fda7b7..07c7eb3 100644 --- a/pkg/provider/github/github.go +++ b/pkg/provider/github/github.go @@ -109,7 +109,7 @@ func (g *Github) StartFlow(ctx context.Context, nextURL string, organization str challenge := pkce.CodeChallengeS256(verifier) state := uuid.New().String() - g.logger.Debug().Msgf("starting flow for state %s", state) + g.logger.Debug().Msgf("starting flow for state %s with org '%s' and device identifier '%s'", state, organization, deviceIdentifier) err := g.database.SetGithubFlow(ctx, state, verifier, challenge, nextURL, organization, deviceIdentifier) if err != nil { return "", err diff --git a/pkg/servicekey/session.go b/pkg/servicesession/servicesession.go similarity index 71% rename from pkg/servicekey/session.go rename to pkg/servicesession/servicesession.go index 752b5b6..2edb173 100644 --- a/pkg/servicekey/session.go +++ b/pkg/servicesession/servicesession.go @@ -14,23 +14,24 @@ limitations under the License. */ -package servicekey +package servicesession import ( "github.com/google/uuid" "github.com/loopholelabs/auth" + "github.com/loopholelabs/auth/pkg/servicekey" "golang.org/x/crypto/bcrypt" ) -// Session represents a user's authenticated service key session -type Session struct { - // ID is the Service Key's unique identifier +// ServiceSession represents a user's authenticated service key session +type ServiceSession struct { + // ID is the service session's unique identifier ID string `json:"id"` - // Hash is the hashed secret of the Service Key session + // Hash is the hashed secret of the service session Hash []byte `json:"hash"` - // ServiceKeyID is the ID of the Service Key that the session is associated with + // ServiceKeyID is the ID of the Service Key that the service session is associated with ServiceKeyID string `json:"service_key_id"` // UserID is the user's unique identifier @@ -46,15 +47,15 @@ type Session struct { ResourceID string `json:"resource_id"` } -// NewSession returns a new session for a user with the given service key -func NewSession(servicekey *ServiceKey) (*Session, []byte, error) { - id := uuid.New().String() - secret := []byte(auth.ServiceKeySessionPrefixString + uuid.New().String()) +// New returns a new service session for a user with the given service key +func New(servicekey *servicekey.ServiceKey) (*ServiceSession, []byte, error) { + id := auth.ServiceSessionPrefixString + uuid.New().String() + secret := []byte(uuid.New().String()) hash, err := bcrypt.GenerateFromPassword(secret, bcrypt.DefaultCost) if err != nil { return nil, nil, err } - return &Session{ + return &ServiceSession{ ID: id, Hash: hash, ServiceKeyID: servicekey.ID, diff --git a/pkg/storage/apikey.go b/pkg/storage/apikey.go new file mode 100644 index 0000000..a3517eb --- /dev/null +++ b/pkg/storage/apikey.go @@ -0,0 +1,54 @@ +/* + Copyright 2023 Loophole Labs + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package storage + +import ( + "context" + "github.com/loopholelabs/auth/pkg/apikey" +) + +// APIKeyEvent is the event that is emitted when an API key is created, updated, or deleted +type APIKeyEvent struct { + // ID is the API Key Identifier + ID string + + // Deleted indicates whether the API Key was deleted + Deleted bool + + // APIKey is the API Key that was created or updated. + // If the API Key was deleted, this will be nil + APIKey *apikey.APIKey +} + +// APIKey is the interface for storage of API Keys. +type APIKey interface { + // GetAPIKey returns the API key for the given id. If + // there is an error while getting the API key, an error is returned. + // If there is no error, the API key is returned. If the API key does not + // exist, ErrNotFound is returned. + GetAPIKey(ctx context.Context, id string) (*apikey.APIKey, error) + + // ListAPIKeys returns a list of all API Keys. If there is an error while + // listing the API keys, an error is returned. If there is no error, the list + // of API keys is returned. + ListAPIKeys(ctx context.Context) ([]*apikey.APIKey, error) + + // SubscribeToAPIKeys subscribes to API key events. When an API key is created, + // updated, or deleted, the event is emitted on the given channel. Cancelling + // the provided context will unsubscribe from API key events. + SubscribeToAPIKeys(ctx context.Context) <-chan *APIKeyEvent +} diff --git a/pkg/storage/registration.go b/pkg/storage/registration.go new file mode 100644 index 0000000..6a53b71 --- /dev/null +++ b/pkg/storage/registration.go @@ -0,0 +1,42 @@ +/* + Copyright 2023 Loophole Labs + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package storage + +import "context" + +// RegistrationEvent is the event that is emitted when registration is enabled or disabled +type RegistrationEvent struct { + // Enabled indicates whether registration is enabled + Enabled bool +} + +// Registration is the interface for storage of registration settings. +type Registration interface { + // SetRegistration sets whether registration is enabled. If there is an error + // while setting the registration status, an error is returned. + SetRegistration(ctx context.Context, enabled bool) error + + // GetRegistration returns whether registration is enabled. If there is an error + // while getting the registration status, an error is returned. If there is no + // error, the boolean indicates whether registration is enabled. + GetRegistration(ctx context.Context) (bool, error) + + // SubscribeToRegistration subscribes to registration events. When registration + // is enabled or disabled, the event is emitted on the given channel. Cancelling + // the provided context will unsubscribe from registration events. + SubscribeToRegistration(ctx context.Context) <-chan *RegistrationEvent +} diff --git a/pkg/storage/secretkey.go b/pkg/storage/secretkey.go new file mode 100644 index 0000000..87cc282 --- /dev/null +++ b/pkg/storage/secretkey.go @@ -0,0 +1,54 @@ +/* + Copyright 2023 Loophole Labs + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package storage + +import ( + "context" + "errors" +) + +var ( + ErrInvalidSecretKey = errors.New("invalid secret key") +) + +// SecretKeyEvent is the event that is emitted when a secret key is rotated +type SecretKeyEvent struct { + // SecretKey is the new secret key + SecretKey []byte +} + +// SecretKey is the interface for managing the secret keys used to sign and verify sessions. +type SecretKey interface { + // SetSecretKey sets the current secret key. If there is an error while + // setting the secret key, an error is returned. + // If there is no error, the secret key is returned. + // The secret key should be exactly 32 bytes long. If it + // is not, ErrInvalidSecretKey is returned. + SetSecretKey(ctx context.Context, secretKey []byte) error + + // GetSecretKey returns the current secret key. If there is an error while + // getting the secret key, an error is returned. If the secret key does not + // exist, ErrNotFound is returned. If there is no error, the secret + // key is returned. The secret key should be exactly 32 bytes long. If it + // is not, ErrInvalidSecretKey is returned. + GetSecretKey(ctx context.Context) ([]byte, error) + + // SubscribeToSecretKey subscribes to secret key events. When the secret key is + // rotated, the event is emitted on the given channel. Cancelling the provided + // context will unsubscribe from secret key events. + SubscribeToSecretKey(ctx context.Context) <-chan *SecretKeyEvent +} diff --git a/pkg/storage/servicekey.go b/pkg/storage/servicekey.go new file mode 100644 index 0000000..0e86e00 --- /dev/null +++ b/pkg/storage/servicekey.go @@ -0,0 +1,34 @@ +/* + Copyright 2023 Loophole Labs + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package storage + +import ( + "context" + "github.com/loopholelabs/auth/pkg/servicekey" +) + +type ServiceKey interface { + // GetServiceKey returns the service key for the given ID. If there is an error + // while getting the service key, an error is returned. If there is no error, the service key + // is returned. If the service key does not exist, ErrNotFound is returned. + GetServiceKey(ctx context.Context, id string) (*servicekey.ServiceKey, error) + + // IncrementServiceKeyNumUsed increments the number of times the service key has been used by increment. + // If there is an error while incrementing the number of times the service key has been used, + // an error is returned. If the service key does not exist, ErrNotFound is returned. + IncrementServiceKeyNumUsed(ctx context.Context, id string, increment int64) error +} diff --git a/pkg/storage/servicesession.go b/pkg/storage/servicesession.go new file mode 100644 index 0000000..f7abaa4 --- /dev/null +++ b/pkg/storage/servicesession.go @@ -0,0 +1,66 @@ +/* + Copyright 2023 Loophole Labs + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package storage + +import ( + "context" + "github.com/loopholelabs/auth/pkg/servicesession" +) + +// ServiceSessionEvent is the event that is triggered when a service session is created, updated, or deleted +type ServiceSessionEvent struct { + // ID is the service session's unique identifier + ID string + + // Deleted indicates whether the service session was deleted + Deleted bool + + // ServiceSession is the service session that was created or updated. + // If the service session was deleted, this will be nil + ServiceSession *servicesession.ServiceSession +} + +type ServiceSession interface { + // SetServiceSession sets the service session for the given serviceSession.ID. If + // there is an error while setting the service session, an error is returned. + // If the user or organization does not exist, ErrNotFound is returned. + // If the organization associated with the service session is not empty, + // service the session is associated with the organization. If the service session is + // associated with an organization and that organization is deleted, the service session + // should also be deleted. If the service session already exists, it ErrAlreadyExists is returned. + SetServiceSession(ctx context.Context, id string, hash []byte, serviceKeyID string) error + + // GetServiceSession gets the service session for the given id. If there is an error + // while getting the service session, an error is returned. If the service session does not + // exist, ErrNotFound is returned. + GetServiceSession(ctx context.Context, id string) (*servicesession.ServiceSession, error) + + // ListServiceSessions returns a list of all service sessions. If there is an error while + // listing the service sessions, an error is returned. + // If there is no error, the list of service sessions is returned. + ListServiceSessions(ctx context.Context) ([]*servicesession.ServiceSession, error) + + // DeleteServiceSession deletes the service session for the given id. If + // there is an error while deleting the service session, an error is returned. + // ErrNotFound is returned if the service session does not exist. + DeleteServiceSession(ctx context.Context, id string) error + + // SubscribeToServiceSessions subscribes to service session events. When a service session is created, + // updated, or deleted, the event is emitted on the given channel. Cancelling + // the provided context will unsubscribe from service session events. + SubscribeToServiceSessions(ctx context.Context) <-chan *ServiceSessionEvent +} diff --git a/pkg/storage/session.go b/pkg/storage/session.go new file mode 100644 index 0000000..58d0f4b --- /dev/null +++ b/pkg/storage/session.go @@ -0,0 +1,73 @@ +/* + Copyright 2023 Loophole Labs + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package storage + +import ( + "context" + "github.com/loopholelabs/auth/pkg/session" + "time" +) + +// SessionEvent is the event that is triggered when a session is created, updated, or deleted +type SessionEvent struct { + // ID is the session's unique identifier + ID string + + // Deleted indicates whether the session was deleted + Deleted bool + + // Session is the session that was created or updated. + // If the session was deleted, this will be nil + Session *session.Session +} + +// Session is the interface for storage of sessions. +type Session interface { + // SetSession sets the session for the given session.ID. If there is an error + // while setting the session, an error is returned. + // If the user or organization does not exist, ErrNotFound is returned. + // If the organization associated with the session is not empty, the session is + // associated with the organization. If the session is associated with an organization + // and that organization is deleted, the session should also be deleted. If the session + // already exists, it ErrAlreadyExists is returned. + SetSession(ctx context.Context, session *session.Session) error + + // GetSession gets the session for the given id. If there is an error + // while getting the session, an error is returned. If the session does not + // exist, ErrNotFound is returned. + GetSession(ctx context.Context, id string) (*session.Session, error) + + // ListSessions returns a list of all sessions. If there is an error while + // listing the sessions, an error is returned. + // If there is no error, the list of sessions is returned. + ListSessions(ctx context.Context) ([]*session.Session, error) + + // DeleteSession deletes the session for the given id. If + // there is an error while deleting the session, an error is returned. + // ErrNotFound is returned if the session does not exist. + DeleteSession(ctx context.Context, id string) error + + // UpdateSessionExpiry updates the expiry of the session for the given id. If + // there is an error while updating the session, an error is returned. If the + // session does not exist, ErrNotFound is returned. + UpdateSessionExpiry(ctx context.Context, id string, expiry time.Time) error + + // SubscribeToSessions subscribes to session events. When a session is created, + // updated, or deleted, the event is emitted on the given channel. Cancelling + // the provided context will unsubscribe from session events. + SubscribeToSessions(ctx context.Context) <-chan *SessionEvent +} diff --git a/pkg/storage/storage.go b/pkg/storage/storage.go index c54779c..23e4e8e 100644 --- a/pkg/storage/storage.go +++ b/pkg/storage/storage.go @@ -17,184 +17,22 @@ package storage import ( - "context" "errors" - "github.com/loopholelabs/auth/pkg/apikey" - "github.com/loopholelabs/auth/pkg/claims" - "github.com/loopholelabs/auth/pkg/servicekey" - "github.com/loopholelabs/auth/pkg/session" ) var ( - ErrNotFound = errors.New("key not found") + ErrNotFound = errors.New("key not found") + ErrAlreadyExists = errors.New("key already exists") ) -// SessionEvent is the event that is triggered when a session is created, updated, or deleted -type SessionEvent struct { - // Session ID is the Session's unique identifier - SessionID string - - // Deleted indicates whether the session was deleted - Deleted bool -} - -// SecretKeyEvent is the event that is emitted when a secret key is rotated -type SecretKeyEvent struct { - // SecretKey is the new secret key - SecretKey []byte -} - -// RegistrationEvent is the event that is emitted when registration is enabled or disabled -type RegistrationEvent struct { - // Enabled indicates whether registration is enabled - Enabled bool -} - -// APIKeyEvent is the event that is emitted when an API key is created, updated, or deleted -type APIKeyEvent struct { - // APIKeyID is the API Key Identifier - APIKeyID string - - // Deleted indicates whether the API Key was deleted - Deleted bool - - // APIKey is the API Key that was created or updated. - // This will be nil if the API Key was deleted. - APIKey *apikey.APIKey -} - -// ServiceKeySessionEvent is the event that is triggered when a service key session is created, updated, or deleted -type ServiceKeySessionEvent struct { - // ServiceKeySessionID is the Service Key Session's unique identifier - ServiceKeySessionID string - - // Deleted indicates whether the session was deleted - Deleted bool - - // ServiceKeySession is the Service Key Session that was created or updated. - // This will be nil if the Service Key Session was deleted. - ServiceKeySession *servicekey.Session -} - // Storage is the interface that must be implemented by the application // using this auth library for authentication and session handling. type Storage interface { - // UserExists verifies whether the given userID exists. If there is an error - // while checking if the user exists, an error is returned, otherwise - // the boolean indicates whether the user exists. An error should not be - // returned if the user does not exist. - UserExists(ctx context.Context, userID string) (bool, error) - // UserOrganizationExists verifies whether the given userID is part of the - // given organization. If there is an error while checking if the user - // exists, an error is returned, otherwise the boolean indicates whether - // the user exists. An error should not be returned if the user does not - // exist or if the user is not part of the organization. - UserOrganizationExists(ctx context.Context, userID string, organization string) (bool, error) - // NewUser creates a new user with the given claims. If the user already - // exists, an error is returned. If the user does not exist, the user is - // created and the claims are set. If there is an error while creating the - // user, an error is returned. - NewUser(ctx context.Context, claims *claims.Claims) error - - // SubscribeToRegistration subscribes to registration events. When registration - // is enabled or disabled, the event is emitted on the given channel. Cancelling - // the provided context will unsubscribe from registration events. - SubscribeToRegistration(ctx context.Context) (<-chan *RegistrationEvent, error) - // GetRegistration returns whether registration is enabled. If there is an error - // while getting the registration status, an error is returned. If there is no - // error, the boolean indicates whether registration is enabled. - GetRegistration(ctx context.Context) (bool, error) - // SetRegistration sets whether registration is enabled. If there is an error - // while setting the registration status, an error is returned. - // If there is no error, the boolean indicates whether registration is enabled. - SetRegistration(ctx context.Context, enabled bool) error - - // SubscribeToSecretKey subscribes to secret key events. When the secret key is - // rotated, the event is emitted on the given channel. Cancelling the provided - // context will unsubscribe from secret key events. - SubscribeToSecretKey(ctx context.Context) (<-chan *SecretKeyEvent, error) - // GetSecretKey returns the current secret key. If there is an error while - // getting the secret key, an error is returned. - // If there is no error, the secret key is returned. - // The secret key should be exactly 32 bytes long. - GetSecretKey(ctx context.Context) ([]byte, error) - // SetSecretKey sets the current secret key. If there is an error while - // setting the secret key, an error is returned. - // If there is no error, the secret key is returned. - // The secret key should be exactly 32 bytes long. - SetSecretKey(ctx context.Context, secretKey []byte) error - - // SubscribeToSessionIDs subscribes to session events. When a session is created, - // updated, or deleted, the event is emitted on the given channel. Cancelling - // the provided context will unsubscribe from session events. - SubscribeToSessionIDs(ctx context.Context) (<-chan *SessionEvent, error) - // ListSessionIDs returns a list of all session IDs. If there is an error while - // listing the session IDs, an error is returned. - // If there is no error, the list of session IDs is returned. - ListSessionIDs(ctx context.Context) ([]string, error) - //SessionIDExists verifies whether the given sessionID exists. If there is an error - // while checking if the sessionID exists, an error is returned, otherwise - // the boolean indicates whether the sessionID exists. An error should not be - // returned if the sessionID does not exist. - SessionIDExists(ctx context.Context, sessionID string) (bool, error) - - // SetSession sets the session for the given session.ID. If there is an error - // while setting the session, an error is returned. If the organization - // associated with the session is not empty, the session is associated with - // the organization. If the organization is empty, the session is associated - // with the user. If the session is associated with an organization and that - // organization is deleted, the session should also be deleted. - SetSession(ctx context.Context, session *session.Session) error - // DeleteSession deletes the session for the given sessionID. If - // there is an error while deleting the session, an error is returned. - // An error is returned if the session does not exist. - DeleteSession(ctx context.Context, sessionID string) error - - // SubscribeToAPIKeys subscribes to API key events. When an API key is created, - // updated, or deleted, the event is emitted on the given channel. Cancelling - // the provided context will unsubscribe from API key events. - SubscribeToAPIKeys(ctx context.Context) (<-chan *APIKeyEvent, error) - // ListAPIKeys returns a list of all API keys. If there is an error while - // listing the API keys, an error is returned. If there is no error, the list - // of API keys is returned. - ListAPIKeys(ctx context.Context) ([]*apikey.APIKey, error) - // GetAPIKey returns the API key for the given API key ID. If - // there is an error while getting the API key, an error is returned. - // If there is no error, the API key is returned. - GetAPIKey(ctx context.Context, id string) (*apikey.APIKey, error) - - // SubscribeToServiceKeySessions subscribes to service key session events. - // When a service key session is created, updated, or deleted, the event is - // emitted on the given channel. Cancelling the provided context will unsubscribe from - // service key session events. - SubscribeToServiceKeySessions(ctx context.Context) (<-chan *ServiceKeySessionEvent, error) - // ListServiceKeySessions returns a list of all service key session IDs. If there is an error while - // listing the service key session IDs, an error is returned. If there is no error, the list - // of service key session IDs is returned. - ListServiceKeySessions(ctx context.Context) ([]*servicekey.Session, error) - // SetServiceKeySession sets the service key session for the given servicekeySession.ID. If - // there is an error while setting the service key session, an error is returned. - // If the organization associated with the service key session is not empty, the service key session is associated with - // the organization. If the organization is empty, the service key session is associated - // with the user. If the service key session is associated with an organization and that - // organization is deleted, the service key session and the service key itself should be deleted. - // If the service key associated with the service key session is deleted, the service key session should be deleted. - SetServiceKeySession(ctx context.Context, servicekeySession *servicekey.Session) error - // GetServiceKeySession returns the service key session for the given servicekeySessionID. If - // there is an error while getting the service key session, an error is returned. - // If there is no error, the service key session is returned. - GetServiceKeySession(ctx context.Context, servicekeySessionID string) (*servicekey.Session, error) - // DeleteServiceKeySession deletes the service key session for the given servicekeySessionID. If - // there is an error while deleting the service key session, an error is returned. - // An error is returned if the service key session does not exist. - DeleteServiceKeySession(ctx context.Context, servicekeySessionID string) error - - // GetServiceKey returns the service key for the given service key ID. If there is an error - // while getting the service key, an error is returned. If there is no error, the service key - // is returned. - GetServiceKey(ctx context.Context, servicekeyID string) (*servicekey.ServiceKey, error) - // IncrementServiceKeyNumUsed increments the number of times the service key has been used. - // If there is an error while incrementing the number of times the service key has been used, - // an error is returned. If the service key does not exist, an error is returned. - IncrementServiceKeyNumUsed(ctx context.Context, servicekeyID string) error + User + Registration + SecretKey + Session + APIKey + ServiceKey + ServiceSession } diff --git a/pkg/storage/user.go b/pkg/storage/user.go new file mode 100644 index 0000000..d9c1573 --- /dev/null +++ b/pkg/storage/user.go @@ -0,0 +1,46 @@ +/* + Copyright 2023 Loophole Labs + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package storage + +import ( + "context" + "github.com/loopholelabs/auth/pkg/claims" +) + +// User is the interface that must be implemented to store user data. +type User interface { + // UserExists verifies whether the given userID exists. If there is an error + // while checking if the user exists, an error is returned, otherwise + // the boolean indicates whether the user exists. An error should not be + // returned if the user does not exist. + UserExists(ctx context.Context, userID string) (bool, error) + + // UserOrganizationExists verifies whether the given userID is part of the + // given organization. If there is an error while checking if the user is + // part of the organization, an error is returned, otherwise the boolean indicates + // whether the user is part of the organization. An error should not be + // returned if the user is not part of the organization, if the organization + // does not exist, or if the user does not exist - instead, the boolean + // should be false. + UserOrganizationExists(ctx context.Context, userID string, organization string) (bool, error) + + // NewUser creates a new user with the given claims. If the user already + // exists, an error is returned. If the user does not exist, the user is + // created and the claims are set. If there is an error while creating the + // user, an error is returned. + NewUser(ctx context.Context, claims *claims.Claims) error +}