diff --git a/internal/api/token.go b/internal/api/token.go index 7f084609a..706c6b33a 100644 --- a/internal/api/token.go +++ b/internal/api/token.go @@ -301,19 +301,20 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request) return sendJSON(w, http.StatusOK, token) } -func (a *API) generateAccessToken(ctx context.Context, tx *storage.Connection, user *models.User, sessionId *uuid.UUID, authenticationMethod models.AuthenticationMethod) (string, int64, error) { +func (a *API) generateAccessToken(ctx context.Context, tx *storage.Connection, user *models.User, sessionId *uuid.UUID, authenticationMethod models.AuthenticationMethod) (string, jwt.Claims, error) { config := a.config aal, amr := models.AAL1.String(), []models.AMREntry{} + var finalClaims jwt.Claims sid := "" if sessionId != nil { sid = sessionId.String() session, terr := models.FindSessionByID(tx, *sessionId, false) if terr != nil { - return "", 0, terr + return "", nil, terr } aal, amr, terr = session.CalculateAALAndAMR(tx) if terr != nil { - return "", 0, terr + return "", nil, terr } } @@ -338,6 +339,7 @@ func (a *API) generateAccessToken(ctx context.Context, tx *storage.Connection, u AuthenticationMethodReference: amr, } + finalClaims = claims var token *jwt.Token if config.Hook.CustomAccessToken.Enabled { input := hooks.CustomAccessTokenInput{ @@ -350,14 +352,14 @@ func (a *API) generateAccessToken(ctx context.Context, tx *storage.Connection, u err := a.invokeHook(ctx, &input, &output) if err != nil { - return "", 0, err + return "", nil, err } - goTrueClaims := jwt.MapClaims(output.Claims) - - token = jwt.NewWithClaims(jwt.SigningMethodHS256, goTrueClaims) + + finalClaims = jwt.MapClaims(output.Claims) + token = jwt.NewWithClaims(jwt.SigningMethodHS256, finalClaims) } else { - token = jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + token = jwt.NewWithClaims(jwt.SigningMethodHS256, finalClaims) } if config.JWT.KeyID != "" { @@ -370,22 +372,23 @@ func (a *API) generateAccessToken(ctx context.Context, tx *storage.Connection, u signed, err := token.SignedString([]byte(config.JWT.Secret)) if err != nil { - return "", 0, err + return "", nil, err } - return signed, expiresAt, nil + return signed, finalClaims, nil } + + func (a *API) issueRefreshToken(ctx context.Context, conn *storage.Connection, user *models.User, authenticationMethod models.AuthenticationMethod, grantParams models.GrantParams) (*AccessTokenResponse, error) { - config := a.config + now := time.Now() user.LastSignInAt = &now var tokenString string - var expiresAt int64 + var tokenClaims jwt.Claims var refreshToken *models.RefreshToken - err := conn.Transaction(func(tx *storage.Connection) error { var terr error @@ -399,7 +402,7 @@ func (a *API) issueRefreshToken(ctx context.Context, conn *storage.Connection, u return terr } - tokenString, expiresAt, terr = a.generateAccessToken(ctx, tx, user, refreshToken.SessionId, authenticationMethod) + tokenString, tokenClaims, terr = a.generateAccessToken(ctx, tx, user, refreshToken.SessionId, authenticationMethod) if terr != nil { // Account for Hook Error httpErr, ok := terr.(*HTTPError) @@ -413,22 +416,14 @@ func (a *API) issueRefreshToken(ctx context.Context, conn *storage.Connection, u if err != nil { return nil, err } - - return &AccessTokenResponse{ - Token: tokenString, - TokenType: "bearer", - ExpiresIn: config.JWT.Exp, - ExpiresAt: expiresAt, - RefreshToken: refreshToken.Token, - User: user, - }, nil + + return a.constructAccessTokenResponse(user, tokenClaims, tokenString, refreshToken.Token), nil } func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection, user *models.User, authenticationMethod models.AuthenticationMethod, grantParams models.GrantParams) (*AccessTokenResponse, error) { ctx := r.Context() - config := a.config var tokenString string - var expiresAt int64 + var tokenClaims jwt.Claims var refreshToken *models.RefreshToken currentClaims := getClaims(ctx) sessionId, err := uuid.FromString(currentClaims.SessionId) @@ -464,7 +459,7 @@ func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection, return err } - tokenString, expiresAt, terr = a.generateAccessToken(ctx, tx, user, &sessionId, models.TOTPSignIn) + tokenString, tokenClaims, terr = a.generateAccessToken(ctx, tx, user, &sessionId, models.TOTPSignIn) if terr != nil { httpErr, ok := terr.(*HTTPError) if ok { @@ -477,14 +472,8 @@ func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection, if err != nil { return nil, err } - return &AccessTokenResponse{ - Token: tokenString, - TokenType: "bearer", - ExpiresIn: config.JWT.Exp, - ExpiresAt: expiresAt, - RefreshToken: refreshToken.Token, - User: user, - }, nil + + return a.constructAccessTokenResponse(user, tokenClaims, tokenString, refreshToken.Token), nil } // setCookieTokens sets the access_token & refresh_token in the cookies @@ -563,3 +552,59 @@ func validateTokenClaims(outputClaims map[string]interface{}) error { return nil } + +func (a *API) constructAccessTokenResponse(user *models.User, claims jwt.Claims, token string, refreshToken string) *AccessTokenResponse { + config := a.config + var expiresIn int + var expiresAt int64 + switch c := claims.(type) { + case *hooks.AccessTokenClaims: + return &AccessTokenResponse{ + Token: token, + TokenType: "bearer", + ExpiresIn: config.JWT.Exp, + ExpiresAt: c.StandardClaims.ExpiresAt, + RefreshToken: refreshToken, + User: user, + } + case jwt.MapClaims: + + if exp, ok := c["exp"].(float64); ok { + // Check for expiresAt within StandardClaims + expiresAt = int64(exp) + expiresIn = int(expiresAt - time.Now().Unix()) + }else{ + expiresIn = config.JWT.Exp + expiresAt = time.Now().Unix() + int64(expiresIn) + } + + // Assign values if the underlying type is map[string]interface{} + if userMetaData, ok := c["user_metadata"].(map[string]interface{}); ok { + user.UserMetaData = userMetaData + } + + if appMetaData, ok := c["app_metadata"].(map[string]interface{}); ok { + user.AppMetaData = appMetaData + } + if role, ok := c["role"].(string); ok { + user.Role = role + } + if subject, ok := c["sub"].(uuid.UUID); ok { // Assuming "sub" is the key for subject + user.ID = subject + } + if audience, ok := c["aud"].(string); ok { // Assuming "aud" is the key for audience + user.Aud = audience + } + return &AccessTokenResponse{ + Token: token, + TokenType: "bearer", + ExpiresIn: expiresIn, + ExpiresAt: expiresAt, + RefreshToken: refreshToken, + User: user, + } + default: + + return nil + } +} \ No newline at end of file diff --git a/internal/api/token_refresh.go b/internal/api/token_refresh.go index 1cd665346..f6d370715 100644 --- a/internal/api/token_refresh.go +++ b/internal/api/token_refresh.go @@ -7,6 +7,8 @@ import ( "net/http" "time" + "github.com/golang-jwt/jwt" + "github.com/supabase/auth/internal/metering" "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/storage" @@ -86,7 +88,7 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h // the connection pool does not get exhausted. var tokenString string - var expiresAt int64 + var tokenClaims jwt.Claims var newTokenResponse *AccessTokenResponse err = db.Transaction(func(tx *storage.Connection) error { @@ -216,8 +218,9 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h issuedToken = newToken } + - tokenString, expiresAt, terr = a.generateAccessToken(ctx, tx, user, issuedToken.SessionId, models.TokenRefresh) + tokenString, tokenClaims, terr = a.generateAccessToken(ctx, tx, user, issuedToken.SessionId, models.TokenRefresh) if terr != nil { httpErr, ok := terr.(*HTTPError) if ok { @@ -225,7 +228,6 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h } return internalServerError("error generating jwt token").WithInternalError(terr) } - refreshedAt := a.Now() session.RefreshedAt = &refreshedAt @@ -247,23 +249,18 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h return internalServerError("failed to update session information").WithInternalError(terr) } - newTokenResponse = &AccessTokenResponse{ - Token: tokenString, - TokenType: "bearer", - ExpiresIn: config.JWT.Exp, - ExpiresAt: expiresAt, - RefreshToken: issuedToken.Token, - User: user, - } + newTokenResponse = a.constructAccessTokenResponse(user, tokenClaims, tokenString, issuedToken.Token) + if terr = a.setCookieTokens(config, newTokenResponse, false, w); terr != nil { return internalServerError("Failed to set JWT cookie. %s", terr) } - return nil }) if err == nil { // success + metering.RecordLogin("token", user.ID) + return sendJSON(w, http.StatusOK, newTokenResponse) } diff --git a/internal/api/token_test.go b/internal/api/token_test.go index 6a8acf6d3..19de6ff3f 100644 --- a/internal/api/token_test.go +++ b/internal/api/token_test.go @@ -530,6 +530,7 @@ func (ts *TokenTestSuite) TestTokenRefreshWithUnexpiredSession() { w := httptest.NewRecorder() ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusOK, w.Code) } @@ -762,6 +763,79 @@ end; $$ language plpgsql;`, } } + +func (ts *TokenTestSuite) TestCustomAccessTokenInResponseUser() { + type customAccessTokenTestcase struct { + desc string + grantType string + requestBody map[string]interface{} + } + var hookFunctionSQL string = ` create or replace function custom_access_token_add_claim(input jsonb) + returns jsonb + language plpgsql + as $$ + declare + result jsonb; + begin + input := jsonb_set(input, '{claims,user_metadata,new_metadata}', '"newvalue"', true); + result := jsonb_build_object('claims', input->'claims'); + return result; + end; + $$;` + cases := []customAccessTokenTestcase{ + { + desc: "check user is updated in refresh token grant type", + grantType: "refresh_token", + requestBody: map[string]interface{}{ + "refresh_token": ts.RefreshToken.Token, + }, + + }, { + desc: "check user is updated in password grant type", + grantType: "password", + requestBody: map[string]interface{}{ + "email": "test@example.com", + "password": "password", + }, + + }, + } + for _, c := range cases { + ts.T().Run(c.desc, func(t *testing.T) { + ts.Config.Hook.CustomAccessToken.Enabled = true + ts.Config.Hook.CustomAccessToken.URI = "pg-functions://postgres/auth/custom_access_token_add_claim" + require.NoError(t, ts.Config.Hook.CustomAccessToken.PopulateExtensibilityPoint()) + + err := ts.API.db.RawQuery(hookFunctionSQL).Exec() + require.NoError(t, err) + + var buffer bytes.Buffer + require.NoError(t, json.NewEncoder(&buffer).Encode(c.requestBody)) + + url := "http://localhost/token?grant_type=" + c.grantType + req := httptest.NewRequest(http.MethodPost, url, &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + var userResponse struct { + User *models.User `json:"user"` + } + + require.NoError(t, json.NewDecoder(w.Result().Body).Decode(&userResponse)) + + value, exists := userResponse.User.UserMetaData["new_metadata"] + require.True(t, exists, "Key 'new_metadata' does not exist in UserMetadata") + require.Equal(t, "newvalue", value, "The value of 'new_metadata' is not 'newvalue'") + + cleanupHookSQL := fmt.Sprintf("drop function if exists %s", ts.Config.Hook.CustomAccessToken.HookName) + require.NoError(t, ts.API.db.RawQuery(cleanupHookSQL).Exec()) + ts.Config.Hook.CustomAccessToken.Enabled = false + }) + } +} + func (ts *TokenTestSuite) TestAllowSelectAuthenticationMethods() { companyUser, err := models.NewUser("12345678", "test@company.com", "password", ts.Config.JWT.Aud, nil) @@ -855,4 +929,4 @@ $$;` ts.Config.Hook.CustomAccessToken.Enabled = false }) } -} +} \ No newline at end of file