Skip to content

Commit

Permalink
feat: add error codes
Browse files Browse the repository at this point in the history
  • Loading branch information
hf committed Feb 24, 2024
1 parent 1ea56b6 commit 7c1b130
Show file tree
Hide file tree
Showing 51 changed files with 811 additions and 501 deletions.
36 changes: 18 additions & 18 deletions internal/api/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context,

userID, err := uuid.FromString(chi.URLParam(r, "user_id"))
if err != nil {
return nil, badRequestError("user_id must be an UUID")
return nil, notFoundError(ErrorCodeValidationFailed, "user_id must be an UUID")
}

observability.LogEntrySetField(r, "user_id", userID)

u, err := models.FindUserByID(db, userID)
if err != nil {
if models.IsNotFoundError(err) {
return nil, notFoundError("User not found")
return nil, notFoundError(ErrorCodeUserNotFound, "User not found")
}
return nil, internalServerError("Database error loading user").WithInternalError(err)
}
Expand All @@ -69,15 +69,15 @@ func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context,
func (a *API) loadFactor(w http.ResponseWriter, r *http.Request) (context.Context, error) {
factorID, err := uuid.FromString(chi.URLParam(r, "factor_id"))
if err != nil {
return nil, badRequestError("factor_id must be an UUID")
return nil, notFoundError(ErrorCodeValidationFailed, "factor_id must be an UUID")
}

observability.LogEntrySetField(r, "factor_id", factorID)

f, err := models.FindFactorByFactorID(a.db, factorID)
if err != nil {
if models.IsNotFoundError(err) {
return nil, notFoundError("Factor not found")
return nil, notFoundError(ErrorCodeMFAFactorNotFound, "Factor not found")
}
return nil, internalServerError("Database error loading factor").WithInternalError(err)
}
Expand All @@ -89,11 +89,11 @@ func (a *API) getAdminParams(r *http.Request) (*AdminUserParams, error) {

body, err := getBodyBytes(r)
if err != nil {
return nil, badRequestError("Could not read body").WithInternalError(err)
return nil, internalServerError("Could not read body").WithInternalError(err)
}

if err := json.Unmarshal(body, &params); err != nil {
return nil, badRequestError("Could not decode admin user params: %v", err)
return nil, badRequestError(ErrorCodeBadJSON, "Could not decode admin user params").WithInternalError(err)
}

return &params, nil
Expand All @@ -107,12 +107,12 @@ func (a *API) adminUsers(w http.ResponseWriter, r *http.Request) error {

pageParams, err := paginate(r)
if err != nil {
return badRequestError("Bad Pagination Parameters: %v", err)
return badRequestError(ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err).WithInternalError(err)
}

sortParams, err := sort(r, map[string]bool{models.CreatedAt: true}, []models.SortField{{Name: models.CreatedAt, Dir: models.Descending}})
if err != nil {
return badRequestError("Bad Sort Parameters: %v", err)
return badRequestError(ErrorCodeValidationFailed, "Bad Sort Parameters: %v", err)
}

filter := r.URL.Query().Get("filter")
Expand Down Expand Up @@ -166,7 +166,7 @@ func (a *API) adminUserUpdate(w http.ResponseWriter, r *http.Request) error {
if params.BanDuration != "none" {
duration, err = time.ParseDuration(params.BanDuration)
if err != nil {
return badRequestError("invalid format for ban duration: %v", err)
return badRequestError(ErrorCodeValidationFailed, "invalid format for ban duration: %v", err)
}
}
if terr := user.Ban(a.db, duration); terr != nil {
Expand Down Expand Up @@ -314,7 +314,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
}

if params.Email == "" && params.Phone == "" {
return unprocessableEntityError("Cannot create a user without either an email or phone")
return badRequestError(ErrorCodeValidationFailed, "Cannot create a user without either an email or phone")
}

var providers []string
Expand All @@ -326,7 +326,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
if user, err := models.IsDuplicatedEmail(db, params.Email, aud, nil); err != nil {
return internalServerError("Database error checking email").WithInternalError(err)
} else if user != nil {
return unprocessableEntityError(DuplicateEmailMsg)
return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg)
}
providers = append(providers, "email")
}
Expand All @@ -339,7 +339,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
if exists, err := models.IsDuplicatedPhone(db, params.Phone, aud); err != nil {
return internalServerError("Database error checking phone").WithInternalError(err)
} else if exists {
return unprocessableEntityError("Phone number already registered by another user")
return unprocessableEntityError(ErrorCodePhoneExists, "Phone number already registered by another user")
}
providers = append(providers, "phone")
}
Expand Down Expand Up @@ -435,7 +435,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
if params.BanDuration != "none" {
duration, err = time.ParseDuration(params.BanDuration)
if err != nil {
return badRequestError("invalid format for ban duration: %v", err)
return badRequestError(ErrorCodeValidationFailed, "invalid format for ban duration: %v", err)
}
}
if terr := user.Ban(a.db, duration); terr != nil {
Expand Down Expand Up @@ -466,11 +466,11 @@ func (a *API) adminUserDelete(w http.ResponseWriter, r *http.Request) error {
params := &adminUserDeleteParams{}
body, err := getBodyBytes(r)
if err != nil {
return badRequestError("Could not read body").WithInternalError(err)
return internalServerError("Could not read body").WithInternalError(err)
}
if len(body) > 0 {
if err := json.Unmarshal(body, params); err != nil {
return badRequestError("Could not read params: %v", err)
return badRequestError(ErrorCodeBadJSON, "Could not read params: %v", err)
}
} else {
params.ShouldSoftDelete = false
Expand Down Expand Up @@ -567,11 +567,11 @@ func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) erro
params := &adminUserUpdateFactorParams{}
body, err := getBodyBytes(r)
if err != nil {
return badRequestError("Could not read body").WithInternalError(err)
return internalServerError("Could not read body").WithInternalError(err)
}

if err := json.Unmarshal(body, params); err != nil {
return badRequestError("Could not read factor update params: %v", err)
return badRequestError(ErrorCodeBadJSON, "Could not read factor update params: %v", err).WithInternalError(err)
}

err = a.db.Transaction(func(tx *storage.Connection) error {
Expand All @@ -582,7 +582,7 @@ func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) erro
}
if params.FactorType != "" {
if params.FactorType != models.TOTP {
return badRequestError("Factor Type not valid")
return badRequestError(ErrorCodeValidationFailed, "Factor Type not valid")
}
if terr := factor.UpdateFactorType(tx, params.FactorType); terr != nil {
return terr
Expand Down
35 changes: 35 additions & 0 deletions internal/api/apiversions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package api

import (
"time"
)

const APIVersionHeaderName = "X-Supabase-Api-Version"

type APIVersion = time.Time

var (
APIVersionInitial = time.Time{}
APIVersion20240101 = time.Date(2024, time.January, 1, 0, 0, 0, 0, time.UTC)
)

func DetermineClosestAPIVersion(date string) (APIVersion, error) {
if date == "" {
return APIVersionInitial, nil
}

parsed, err := time.ParseInLocation("2006-01-02", date, time.UTC)
if err != nil {
return APIVersionInitial, err
}

if parsed.Compare(APIVersion20240101) >= 0 {
return APIVersion20240101, nil
}

return APIVersionInitial, nil
}

func FormatAPIVersion(apiVersion APIVersion) string {
return apiVersion.Format("2006-01-02")
}
29 changes: 29 additions & 0 deletions internal/api/apiversions_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package api

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestDetermineClosestAPIVersion(t *testing.T) {
version, err := DetermineClosestAPIVersion("")
require.NoError(t, err)
require.Equal(t, APIVersionInitial, version)

version, err = DetermineClosestAPIVersion("Not a date")
require.Error(t, err)
require.Equal(t, APIVersionInitial, version)

version, err = DetermineClosestAPIVersion("2023-12-31")
require.NoError(t, err)
require.Equal(t, APIVersionInitial, version)

version, err = DetermineClosestAPIVersion("2024-01-01")
require.NoError(t, err)
require.Equal(t, APIVersion20240101, version)

version, err = DetermineClosestAPIVersion("2024-01-02")
require.NoError(t, err)
require.Equal(t, APIVersion20240101, version)
}
4 changes: 2 additions & 2 deletions internal/api/audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func (a *API) adminAuditLog(w http.ResponseWriter, r *http.Request) error {
// aud := a.requestAud(ctx, r)
pageParams, err := paginate(r)
if err != nil {
return badRequestError("Bad Pagination Parameters: %v", err)
return badRequestError(ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err)
}

var col []string
Expand All @@ -31,7 +31,7 @@ func (a *API) adminAuditLog(w http.ResponseWriter, r *http.Request) error {
qparts := strings.SplitN(q, ":", 2)
col, exists = filterColumnMap[qparts[0]]
if !exists || len(qparts) < 2 {
return badRequestError("Invalid query scope: %s", q)
return badRequestError(ErrorCodeValidationFailed, "Invalid query scope: %s", q)
}
qval = qparts[1]
}
Expand Down
20 changes: 10 additions & 10 deletions internal/api/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func (a *API) requireAdmin(ctx context.Context, w http.ResponseWriter, r *http.R
claims := getClaims(ctx)
if claims == nil {
fmt.Printf("[%s] %s %s %d %s\n", time.Now().Format("2006-01-02 15:04:05"), r.Method, r.RequestURI, http.StatusForbidden, "Invalid token")
return nil, unauthorizedError("Invalid token")
return nil, forbiddenError(ErrorCodeBadJWT, "Invalid token")
}

adminRoles := a.config.JWT.AdminRoles
Expand All @@ -51,14 +51,14 @@ func (a *API) requireAdmin(ctx context.Context, w http.ResponseWriter, r *http.R
}

fmt.Printf("[%s] %s %s %d %s\n", time.Now().Format("2006-01-02 15:04:05"), r.Method, r.RequestURI, http.StatusForbidden, "this token needs role 'supabase_admin' or 'service_role'")
return nil, unauthorizedError("User not allowed")
return nil, forbiddenError(ErrorCodeNotAdmin, "User not allowed")
}

func (a *API) extractBearerToken(r *http.Request) (string, error) {
authHeader := r.Header.Get("Authorization")
matches := bearerRegexp.FindStringSubmatch(authHeader)
if len(matches) != 2 {
return "", unauthorizedError("This endpoint requires a Bearer token")
return "", httpError(http.StatusUnauthorized, ErrorCodeNoAuthorization, "This endpoint requires a Bearer token")
}

return matches[1], nil
Expand All @@ -73,7 +73,7 @@ func (a *API) parseJWTClaims(bearer string, r *http.Request) (context.Context, e
return []byte(config.JWT.Secret), nil
})
if err != nil {
return nil, unauthorizedError("invalid JWT: unable to parse or verify signature, %v", err)
return nil, forbiddenError(ErrorCodeBadJWT, "invalid JWT: unable to parse or verify signature, %v", err).WithInternalError(err)
}

return withToken(ctx, token), nil
Expand All @@ -84,23 +84,23 @@ func (a *API) maybeLoadUserOrSession(ctx context.Context) (context.Context, erro
claims := getClaims(ctx)

if claims == nil {
return ctx, unauthorizedError("invalid token: missing claims")
return ctx, forbiddenError(ErrorCodeBadJWT, "invalid token: missing claims")
}

if claims.Subject == "" {
return nil, unauthorizedError("invalid claim: missing sub claim")
return nil, forbiddenError(ErrorCodeBadJWT, "invalid claim: missing sub claim")
}

var user *models.User
if claims.Subject != "" {
userId, err := uuid.FromString(claims.Subject)
if err != nil {
return ctx, badRequestError("invalid claim: sub claim must be a UUID").WithInternalError(err)
return ctx, badRequestError(ErrorCodeBadJWT, "invalid claim: sub claim must be a UUID").WithInternalError(err)
}
user, err = models.FindUserByID(db, userId)
if err != nil {
if models.IsNotFoundError(err) {
return ctx, notFoundError(err.Error())
return ctx, forbiddenError(ErrorCodeUserNotFound, "User from sub claim in JWT does not exist")
}
return ctx, err
}
Expand All @@ -111,11 +111,11 @@ func (a *API) maybeLoadUserOrSession(ctx context.Context) (context.Context, erro
if claims.SessionId != "" && claims.SessionId != uuid.Nil.String() {
sessionId, err := uuid.FromString(claims.SessionId)
if err != nil {
return ctx, err
return ctx, forbiddenError(ErrorCodeBadJWT, "invalid claim: session_id claim must be a UUID").WithInternalError(err)
}
session, err = models.FindSessionByID(db, sessionId, false)
if err != nil && !models.IsNotFoundError(err) {
return ctx, err
return ctx, forbiddenError(ErrorCodeSessionNotFound, "Session from session_id claim in JWT does not exist")
}
ctx = withSession(ctx, session)
}
Expand Down
4 changes: 2 additions & 2 deletions internal/api/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (ts *AuthTestSuite) TestMaybeLoadUserOrSession() {
},
Role: "authenticated",
},
ExpectedError: unauthorizedError("invalid claim: missing sub claim"),
ExpectedError: forbiddenError(ErrorCodeBadJWT, "invalid claim: missing sub claim"),
ExpectedUser: nil,
},
{
Expand All @@ -118,7 +118,7 @@ func (ts *AuthTestSuite) TestMaybeLoadUserOrSession() {
},
Role: "authenticated",
},
ExpectedError: badRequestError("invalid claim: sub claim must be a UUID"),
ExpectedError: badRequestError(ErrorCodeBadJWT, "invalid claim: sub claim must be a UUID"),
ExpectedUser: nil,
},
{
Expand Down
Loading

0 comments on commit 7c1b130

Please sign in to comment.