Skip to content

Commit

Permalink
feat: add error codes
Browse files Browse the repository at this point in the history
  • Loading branch information
hf committed Feb 20, 2024
1 parent 1ea56b6 commit 13a889c
Show file tree
Hide file tree
Showing 51 changed files with 804 additions and 501 deletions.
36 changes: 18 additions & 18 deletions internal/api/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context,

userID, err := uuid.FromString(chi.URLParam(r, "user_id"))
if err != nil {
return nil, badRequestError("user_id must be an UUID")
return nil, notFoundError(ErrorCodeValidationFailed, "user_id must be an UUID")
}

observability.LogEntrySetField(r, "user_id", userID)

u, err := models.FindUserByID(db, userID)
if err != nil {
if models.IsNotFoundError(err) {
return nil, notFoundError("User not found")
return nil, notFoundError(ErrorCodeUserNotFound, "User not found")
}
return nil, internalServerError("Database error loading user").WithInternalError(err)
}
Expand All @@ -69,15 +69,15 @@ func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context,
func (a *API) loadFactor(w http.ResponseWriter, r *http.Request) (context.Context, error) {
factorID, err := uuid.FromString(chi.URLParam(r, "factor_id"))
if err != nil {
return nil, badRequestError("factor_id must be an UUID")
return nil, notFoundError(ErrorCodeValidationFailed, "factor_id must be an UUID")
}

observability.LogEntrySetField(r, "factor_id", factorID)

f, err := models.FindFactorByFactorID(a.db, factorID)
if err != nil {
if models.IsNotFoundError(err) {
return nil, notFoundError("Factor not found")
return nil, notFoundError(ErrorCodeMFAFactorNotFound, "Factor not found")
}
return nil, internalServerError("Database error loading factor").WithInternalError(err)
}
Expand All @@ -89,11 +89,11 @@ func (a *API) getAdminParams(r *http.Request) (*AdminUserParams, error) {

body, err := getBodyBytes(r)
if err != nil {
return nil, badRequestError("Could not read body").WithInternalError(err)
return nil, internalServerError("Could not read body").WithInternalError(err)
}

if err := json.Unmarshal(body, &params); err != nil {
return nil, badRequestError("Could not decode admin user params: %v", err)
return nil, badRequestError(ErrorCodeBadJSON, "Could not decode admin user params").WithInternalError(err)
}

return &params, nil
Expand All @@ -107,12 +107,12 @@ func (a *API) adminUsers(w http.ResponseWriter, r *http.Request) error {

pageParams, err := paginate(r)
if err != nil {
return badRequestError("Bad Pagination Parameters: %v", err)
return badRequestError(ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err).WithInternalError(err)
}

sortParams, err := sort(r, map[string]bool{models.CreatedAt: true}, []models.SortField{{Name: models.CreatedAt, Dir: models.Descending}})
if err != nil {
return badRequestError("Bad Sort Parameters: %v", err)
return badRequestError(ErrorCodeValidationFailed, "Bad Sort Parameters: %v", err)
}

filter := r.URL.Query().Get("filter")
Expand Down Expand Up @@ -166,7 +166,7 @@ func (a *API) adminUserUpdate(w http.ResponseWriter, r *http.Request) error {
if params.BanDuration != "none" {
duration, err = time.ParseDuration(params.BanDuration)
if err != nil {
return badRequestError("invalid format for ban duration: %v", err)
return badRequestError(ErrorCodeValidationFailed, "invalid format for ban duration: %v", err)
}
}
if terr := user.Ban(a.db, duration); terr != nil {
Expand Down Expand Up @@ -314,7 +314,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
}

if params.Email == "" && params.Phone == "" {
return unprocessableEntityError("Cannot create a user without either an email or phone")
return badRequestError(ErrorCodeValidationFailed, "Cannot create a user without either an email or phone")
}

var providers []string
Expand All @@ -326,7 +326,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
if user, err := models.IsDuplicatedEmail(db, params.Email, aud, nil); err != nil {
return internalServerError("Database error checking email").WithInternalError(err)
} else if user != nil {
return unprocessableEntityError(DuplicateEmailMsg)
return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg)
}
providers = append(providers, "email")
}
Expand All @@ -339,7 +339,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
if exists, err := models.IsDuplicatedPhone(db, params.Phone, aud); err != nil {
return internalServerError("Database error checking phone").WithInternalError(err)
} else if exists {
return unprocessableEntityError("Phone number already registered by another user")
return unprocessableEntityError(ErrorCodePhoneExists, "Phone number already registered by another user")
}
providers = append(providers, "phone")
}
Expand Down Expand Up @@ -435,7 +435,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
if params.BanDuration != "none" {
duration, err = time.ParseDuration(params.BanDuration)
if err != nil {
return badRequestError("invalid format for ban duration: %v", err)
return badRequestError(ErrorCodeValidationFailed, "invalid format for ban duration: %v", err)
}
}
if terr := user.Ban(a.db, duration); terr != nil {
Expand Down Expand Up @@ -466,11 +466,11 @@ func (a *API) adminUserDelete(w http.ResponseWriter, r *http.Request) error {
params := &adminUserDeleteParams{}
body, err := getBodyBytes(r)
if err != nil {
return badRequestError("Could not read body").WithInternalError(err)
return internalServerError("Could not read body").WithInternalError(err)
}
if len(body) > 0 {
if err := json.Unmarshal(body, params); err != nil {
return badRequestError("Could not read params: %v", err)
return badRequestError(ErrorCodeBadJSON, "Could not read params: %v", err)
}
} else {
params.ShouldSoftDelete = false
Expand Down Expand Up @@ -567,11 +567,11 @@ func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) erro
params := &adminUserUpdateFactorParams{}
body, err := getBodyBytes(r)
if err != nil {
return badRequestError("Could not read body").WithInternalError(err)
return internalServerError("Could not read body").WithInternalError(err)
}

if err := json.Unmarshal(body, params); err != nil {
return badRequestError("Could not read factor update params: %v", err)
return badRequestError(ErrorCodeBadJSON, "Could not read factor update params: %v", err).WithInternalError(err)
}

err = a.db.Transaction(func(tx *storage.Connection) error {
Expand All @@ -582,7 +582,7 @@ func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) erro
}
if params.FactorType != "" {
if params.FactorType != models.TOTP {
return badRequestError("Factor Type not valid")
return badRequestError(ErrorCodeValidationFailed, "Factor Type not valid")
}
if terr := factor.UpdateFactorType(tx, params.FactorType); terr != nil {
return terr
Expand Down
31 changes: 31 additions & 0 deletions internal/api/apiversions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package api

import (
"time"
)

const APIVersionHeaderName = "X-Supabase-Api-Version"

type APIVersion = time.Time

var (
APIVersionInitial = time.Time{}
APIVersion20240101 = time.Date(2024, time.January, 1, 0, 0, 0, 0, time.UTC)
)

func DetermineClosestAPIVersion(date string) (APIVersion, error) {
if date == "" {
return APIVersionInitial, nil
}

parsed, err := time.ParseInLocation("2006-01-02", date, time.UTC)
if err != nil {
return APIVersionInitial, err
}

if parsed.Compare(APIVersion20240101) >= 0 {
return APIVersion20240101, nil
}

return APIVersionInitial, nil
}
29 changes: 29 additions & 0 deletions internal/api/apiversions_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package api

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestDetermineClosestAPIVersion(t *testing.T) {
version, err := DetermineClosestAPIVersion("")
require.NoError(t, err)
require.Equal(t, APIVersionInitial, version)

version, err = DetermineClosestAPIVersion("Not a date")
require.Error(t, err)
require.Equal(t, APIVersionInitial, version)

version, err = DetermineClosestAPIVersion("2023-12-31")
require.NoError(t, err)
require.Equal(t, APIVersionInitial, version)

version, err = DetermineClosestAPIVersion("2024-01-01")
require.NoError(t, err)
require.Equal(t, APIVersion20240101, version)

version, err = DetermineClosestAPIVersion("2024-01-02")
require.NoError(t, err)
require.Equal(t, APIVersion20240101, version)
}
4 changes: 2 additions & 2 deletions internal/api/audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func (a *API) adminAuditLog(w http.ResponseWriter, r *http.Request) error {
// aud := a.requestAud(ctx, r)
pageParams, err := paginate(r)
if err != nil {
return badRequestError("Bad Pagination Parameters: %v", err)
return badRequestError(ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err)
}

var col []string
Expand All @@ -31,7 +31,7 @@ func (a *API) adminAuditLog(w http.ResponseWriter, r *http.Request) error {
qparts := strings.SplitN(q, ":", 2)
col, exists = filterColumnMap[qparts[0]]
if !exists || len(qparts) < 2 {
return badRequestError("Invalid query scope: %s", q)
return badRequestError(ErrorCodeValidationFailed, "Invalid query scope: %s", q)
}
qval = qparts[1]
}
Expand Down
20 changes: 10 additions & 10 deletions internal/api/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func (a *API) requireAdmin(ctx context.Context, w http.ResponseWriter, r *http.R
claims := getClaims(ctx)
if claims == nil {
fmt.Printf("[%s] %s %s %d %s\n", time.Now().Format("2006-01-02 15:04:05"), r.Method, r.RequestURI, http.StatusForbidden, "Invalid token")
return nil, unauthorizedError("Invalid token")
return nil, forbiddenError(ErrorCodeBadJWT, "Invalid token")
}

adminRoles := a.config.JWT.AdminRoles
Expand All @@ -51,14 +51,14 @@ func (a *API) requireAdmin(ctx context.Context, w http.ResponseWriter, r *http.R
}

fmt.Printf("[%s] %s %s %d %s\n", time.Now().Format("2006-01-02 15:04:05"), r.Method, r.RequestURI, http.StatusForbidden, "this token needs role 'supabase_admin' or 'service_role'")
return nil, unauthorizedError("User not allowed")
return nil, forbiddenError(ErrorCodeNotAdmin, "User not allowed")
}

func (a *API) extractBearerToken(r *http.Request) (string, error) {
authHeader := r.Header.Get("Authorization")
matches := bearerRegexp.FindStringSubmatch(authHeader)
if len(matches) != 2 {
return "", unauthorizedError("This endpoint requires a Bearer token")
return "", httpError(http.StatusUnauthorized, ErrorCodeNoAuthorization, "This endpoint requires a Bearer token")
}

return matches[1], nil
Expand All @@ -73,7 +73,7 @@ func (a *API) parseJWTClaims(bearer string, r *http.Request) (context.Context, e
return []byte(config.JWT.Secret), nil
})
if err != nil {
return nil, unauthorizedError("invalid JWT: unable to parse or verify signature, %v", err)
return nil, forbiddenError(ErrorCodeBadJWT, "invalid JWT: unable to parse or verify signature, %v", err).WithInternalError(err)
}

return withToken(ctx, token), nil
Expand All @@ -84,23 +84,23 @@ func (a *API) maybeLoadUserOrSession(ctx context.Context) (context.Context, erro
claims := getClaims(ctx)

if claims == nil {
return ctx, unauthorizedError("invalid token: missing claims")
return ctx, forbiddenError(ErrorCodeBadJWT, "invalid token: missing claims")
}

if claims.Subject == "" {
return nil, unauthorizedError("invalid claim: missing sub claim")
return nil, forbiddenError(ErrorCodeBadJWT, "invalid claim: missing sub claim")
}

var user *models.User
if claims.Subject != "" {
userId, err := uuid.FromString(claims.Subject)
if err != nil {
return ctx, badRequestError("invalid claim: sub claim must be a UUID").WithInternalError(err)
return ctx, badRequestError(ErrorCodeBadJWT, "invalid claim: sub claim must be a UUID").WithInternalError(err)
}
user, err = models.FindUserByID(db, userId)
if err != nil {
if models.IsNotFoundError(err) {
return ctx, notFoundError(err.Error())
return ctx, forbiddenError(ErrorCodeUserNotFound, "User from sub claim in JWT does not exist")
}
return ctx, err
}
Expand All @@ -111,11 +111,11 @@ func (a *API) maybeLoadUserOrSession(ctx context.Context) (context.Context, erro
if claims.SessionId != "" && claims.SessionId != uuid.Nil.String() {
sessionId, err := uuid.FromString(claims.SessionId)
if err != nil {
return ctx, err
return ctx, forbiddenError(ErrorCodeBadJWT, "invalid claim: session_id claim must be a UUID").WithInternalError(err)
}
session, err = models.FindSessionByID(db, sessionId, false)
if err != nil && !models.IsNotFoundError(err) {
return ctx, err
return ctx, forbiddenError(ErrorCodeSessionNotFound, "Session from session_id claim in JWT does not exist")
}
ctx = withSession(ctx, session)
}
Expand Down
4 changes: 2 additions & 2 deletions internal/api/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (ts *AuthTestSuite) TestMaybeLoadUserOrSession() {
},
Role: "authenticated",
},
ExpectedError: unauthorizedError("invalid claim: missing sub claim"),
ExpectedError: forbiddenError(ErrorCodeBadJWT, "invalid claim: missing sub claim"),
ExpectedUser: nil,
},
{
Expand All @@ -118,7 +118,7 @@ func (ts *AuthTestSuite) TestMaybeLoadUserOrSession() {
},
Role: "authenticated",
},
ExpectedError: badRequestError("invalid claim: sub claim must be a UUID"),
ExpectedError: badRequestError(ErrorCodeBadJWT, "invalid claim: sub claim must be a UUID"),
ExpectedUser: nil,
},
{
Expand Down
76 changes: 76 additions & 0 deletions internal/api/errorcodes.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package api

type ErrorCode = string

const (
// ErrorCodeUnknown should not be used directly, it only indicates a failure in the error handling system in such a way that an error code was not assigned properly.
ErrorCodeUnknown ErrorCode = "unknown"

// ErrorCodeUnexpectedFailure signals an unexpected failure such as a 500 Internal Server Error.
ErrorCodeUnexpectedFailure ErrorCode = "unexpected_failure"

ErrorCodeValidationFailed ErrorCode = "validation_failed"
ErrorCodeBadJSON ErrorCode = "bad_json"
ErrorCodeEmailExists ErrorCode = "email_exists"
ErrorCodePhoneExists ErrorCode = "phone_exists"
ErrorCodeBadJWT ErrorCode = "bad_jwt"
ErrorCodeNotAdmin ErrorCode = "not_admin"
ErrorCodeNoAuthorization ErrorCode = "no_authorization"
ErrorCodeUserNotFound ErrorCode = "user_not_found"
ErrorCodeSessionNotFound ErrorCode = "session_not_found"
ErrorCodeFlowStateNotFound ErrorCode = "flow_state_not_found"
ErrorCodeFlowStateExpired ErrorCode = "flow_state_expired"
ErrorCodeSignupDisabled ErrorCode = "signup_disabled"
ErrorCodeUserBanned ErrorCode = "user_banned"
ErrorCodeOverEmailSendRate ErrorCode = "over_email_send_rate"
ErrorCodeOverSMSSendRate ErrorCode = "over_sms_send_rate"
ErrorCodeProviderEmailNeedsVerification ErrorCode = "provider_email_needs_verification"
ErrorCodeInviteNotFound ErrorCode = "invite_not_found"
ErrorCodeBadOAuthState ErrorCode = "bad_oauth_state"
ErrorCodeBadOAuthCallback ErrorCode = "bad_oauth_callback"
ErrorCodeOAuthProviderNotSupported ErrorCode = "oauth_provider_not_supported"
ErrorCodeUnexpectedAudience ErrorCode = "unexpected_audience"
ErrorCodeLastIdentityNotDeletable ErrorCode = "last_identity_not_deletable"
ErrorCodeEmailConflictIdentityNotDeletable ErrorCode = "email_conflict_identity_not_deletable"
ErrorCodeIdentityAlreadyExists ErrorCode = "identity_already_exists"
ErrorCodeEmailProviderDisabled ErrorCode = "email_provider_disabled"
ErrorCodePhoneProviderDisabled ErrorCode = "phone_provider_disabled"
ErrorCodeTooManyEnrolledMFAFactors ErrorCode = "too_many_enrolled_mfa_factors"
ErrorCodeMFAFactorNameConflict ErrorCode = "mfa_factor_name_conflict"
ErrorCodeMFAFactorNotFound ErrorCode = "mfa_factor_not_found"
ErrorCodeMFAIPAddressMismatch ErrorCode = "mfa_ip_address_mismatch"
ErrorCodeMFAChallengeExpired ErrorCode = "mfa_challenge_expired"
ErrorCodeMFAVerificationFailed ErrorCode = "mfa_verification_failed"
ErrorCodeMFAVerificationRejected ErrorCode = "mfa_verification_rejected"
ErrorCodeInsufficientAAL ErrorCode = "insufficient_aal"
ErrorCodeCaptchaFailed ErrorCode = "captcha_failed"
ErrorCodeSAMLProviderDisabled ErrorCode = "saml_provider_disabled"
ErrorCodeManualLinkingDisabled ErrorCode = "manual_linking_disabled"
ErrorCodeSMSSendFailed ErrorCode = "sms_send_failed"
ErrorCodeEmailNotConfirmed ErrorCode = "email_not_confirmed"
ErrorCodePhoneNotConfirmed ErrorCode = "phone_not_confirmed"
ErrorCodeReauthNonceMissing ErrorCode = "reauth_nonce_missing"
ErrorCodeSAMLRelayStateNotFound ErrorCode = "saml_relay_state_not_found"
ErrorCodeSAMLRelayStateExpired ErrorCode = "saml_relay_state_expired"
ErrorCodeSAMLIdPNotFound ErrorCode = "saml_idp_not_found"
ErrorCodeSAMLAssertionNoUserID ErrorCode = "saml_assertion_no_user_id"
ErrorCodeSAMLAssertionNoEmail ErrorCode = "saml_assertion_no_email"
ErrorCodeUserAlreadyExists ErrorCode = "user_already_exists"
ErrorCodeSSOProviderNotFound ErrorCode = "sso_provider_not_found"
ErrorCodeSAMLMetadataFetchFailed ErrorCode = "saml_metadata_fetch_failed"
ErrorCodeSAMLIdPAlreadyExists ErrorCode = "saml_idp_already_exists"
ErrorCodeSSODomainAlreadyExists ErrorCode = "sso_domain_already_exists"
ErrorCodeSAMLEntityIDMismatch ErrorCode = "saml_entity_id_mismatch"
ErrorCodeConflict ErrorCode = "conflict"
ErrorCodeProviderDisabled ErrorCode = "provider_disabled"
ErrorCodeUserSSOManaged ErrorCode = "user_sso_managed"
ErrorCodeReauthenticationNeeded ErrorCode = "reauthentication_needed"
ErrorCodeSamePassword ErrorCode = "same_password"
ErrorCodeReauthenticationNotValid ErrorCode = "reauthentication_not_valid"
ErrorCodeOTPExpired ErrorCode = "otp_expired"
ErrorCodeOTPDisabled ErrorCode = "otp_disabled"
ErrorCodeIdentityNotFound ErrorCode = "identity_not_found"
ErrorCodeWeakPassword ErrorCode = "weak_password"
ErrorCodeOverRequestRate ErrorCode = "over_request_rate"
ErrorBadCodeVerifier ErrorCode = "bad_code_verifier"
)
Loading

0 comments on commit 13a889c

Please sign in to comment.