Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: customAuthHook update REST response #1459

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 79 additions & 34 deletions internal/api/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,19 +301,20 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request)
return sendJSON(w, http.StatusOK, token)
}

func (a *API) generateAccessToken(ctx context.Context, tx *storage.Connection, user *models.User, sessionId *uuid.UUID, authenticationMethod models.AuthenticationMethod) (string, int64, error) {
func (a *API) generateAccessToken(ctx context.Context, tx *storage.Connection, user *models.User, sessionId *uuid.UUID, authenticationMethod models.AuthenticationMethod) (string, jwt.Claims, error) {
config := a.config
aal, amr := models.AAL1.String(), []models.AMREntry{}
var finalClaims jwt.Claims
sid := ""
if sessionId != nil {
sid = sessionId.String()
session, terr := models.FindSessionByID(tx, *sessionId, false)
if terr != nil {
return "", 0, terr
return "", nil, terr
}
aal, amr, terr = session.CalculateAALAndAMR(tx)
if terr != nil {
return "", 0, terr
return "", nil, terr
}
}

Expand All @@ -338,6 +339,7 @@ func (a *API) generateAccessToken(ctx context.Context, tx *storage.Connection, u
AuthenticationMethodReference: amr,
}

finalClaims = claims
var token *jwt.Token
if config.Hook.CustomAccessToken.Enabled {
input := hooks.CustomAccessTokenInput{
Expand All @@ -350,14 +352,14 @@ func (a *API) generateAccessToken(ctx context.Context, tx *storage.Connection, u

err := a.invokeHook(ctx, &input, &output)
if err != nil {
return "", 0, err
return "", nil, err
}
goTrueClaims := jwt.MapClaims(output.Claims)

token = jwt.NewWithClaims(jwt.SigningMethodHS256, goTrueClaims)

finalClaims = jwt.MapClaims(output.Claims)
token = jwt.NewWithClaims(jwt.SigningMethodHS256, finalClaims)

} else {
token = jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
token = jwt.NewWithClaims(jwt.SigningMethodHS256, finalClaims)
}

if config.JWT.KeyID != "" {
Expand All @@ -370,22 +372,23 @@ func (a *API) generateAccessToken(ctx context.Context, tx *storage.Connection, u

signed, err := token.SignedString([]byte(config.JWT.Secret))
if err != nil {
return "", 0, err
return "", nil, err
}

return signed, expiresAt, nil
return signed, finalClaims, nil
}



func (a *API) issueRefreshToken(ctx context.Context, conn *storage.Connection, user *models.User, authenticationMethod models.AuthenticationMethod, grantParams models.GrantParams) (*AccessTokenResponse, error) {
config := a.config


now := time.Now()
user.LastSignInAt = &now

var tokenString string
var expiresAt int64
var tokenClaims jwt.Claims
var refreshToken *models.RefreshToken

err := conn.Transaction(func(tx *storage.Connection) error {
var terr error

Expand All @@ -399,7 +402,7 @@ func (a *API) issueRefreshToken(ctx context.Context, conn *storage.Connection, u
return terr
}

tokenString, expiresAt, terr = a.generateAccessToken(ctx, tx, user, refreshToken.SessionId, authenticationMethod)
tokenString, tokenClaims, terr = a.generateAccessToken(ctx, tx, user, refreshToken.SessionId, authenticationMethod)
if terr != nil {
// Account for Hook Error
httpErr, ok := terr.(*HTTPError)
Expand All @@ -413,22 +416,14 @@ func (a *API) issueRefreshToken(ctx context.Context, conn *storage.Connection, u
if err != nil {
return nil, err
}

return &AccessTokenResponse{
Token: tokenString,
TokenType: "bearer",
ExpiresIn: config.JWT.Exp,
ExpiresAt: expiresAt,
RefreshToken: refreshToken.Token,
User: user,
}, nil

return a.constructAccessTokenResponse(user, tokenClaims, tokenString, refreshToken.Token), nil
}

func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection, user *models.User, authenticationMethod models.AuthenticationMethod, grantParams models.GrantParams) (*AccessTokenResponse, error) {
ctx := r.Context()
config := a.config
var tokenString string
var expiresAt int64
var tokenClaims jwt.Claims
var refreshToken *models.RefreshToken
currentClaims := getClaims(ctx)
sessionId, err := uuid.FromString(currentClaims.SessionId)
Expand Down Expand Up @@ -464,7 +459,7 @@ func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection,
return err
}

tokenString, expiresAt, terr = a.generateAccessToken(ctx, tx, user, &sessionId, models.TOTPSignIn)
tokenString, tokenClaims, terr = a.generateAccessToken(ctx, tx, user, &sessionId, models.TOTPSignIn)
if terr != nil {
httpErr, ok := terr.(*HTTPError)
if ok {
Expand All @@ -477,14 +472,8 @@ func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection,
if err != nil {
return nil, err
}
return &AccessTokenResponse{
Token: tokenString,
TokenType: "bearer",
ExpiresIn: config.JWT.Exp,
ExpiresAt: expiresAt,
RefreshToken: refreshToken.Token,
User: user,
}, nil

return a.constructAccessTokenResponse(user, tokenClaims, tokenString, refreshToken.Token), nil
}

// setCookieTokens sets the access_token & refresh_token in the cookies
Expand Down Expand Up @@ -563,3 +552,59 @@ func validateTokenClaims(outputClaims map[string]interface{}) error {

return nil
}

func (a *API) constructAccessTokenResponse(user *models.User, claims jwt.Claims, token string, refreshToken string) *AccessTokenResponse {
config := a.config
var expiresIn int
var expiresAt int64
switch c := claims.(type) {
case *hooks.AccessTokenClaims:
return &AccessTokenResponse{
Token: token,
TokenType: "bearer",
ExpiresIn: config.JWT.Exp,
ExpiresAt: c.StandardClaims.ExpiresAt,
RefreshToken: refreshToken,
User: user,
}
case jwt.MapClaims:

if exp, ok := c["exp"].(float64); ok {
// Check for expiresAt within StandardClaims
expiresAt = int64(exp)
expiresIn = int(expiresAt - time.Now().Unix())
}else{
expiresIn = config.JWT.Exp
expiresAt = time.Now().Unix() + int64(expiresIn)
}

// Assign values if the underlying type is map[string]interface{}
if userMetaData, ok := c["user_metadata"].(map[string]interface{}); ok {
user.UserMetaData = userMetaData
}

if appMetaData, ok := c["app_metadata"].(map[string]interface{}); ok {
user.AppMetaData = appMetaData
}
if role, ok := c["role"].(string); ok {
user.Role = role
}
if subject, ok := c["sub"].(uuid.UUID); ok { // Assuming "sub" is the key for subject
user.ID = subject
}
if audience, ok := c["aud"].(string); ok { // Assuming "aud" is the key for audience
user.Aud = audience
}
return &AccessTokenResponse{
Token: token,
TokenType: "bearer",
ExpiresIn: expiresIn,
ExpiresAt: expiresAt,
RefreshToken: refreshToken,
User: user,
}
default:

return nil
}
}
21 changes: 9 additions & 12 deletions internal/api/token_refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"net/http"
"time"

"github.com/golang-jwt/jwt"

"github.com/supabase/auth/internal/metering"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/storage"
Expand Down Expand Up @@ -86,7 +88,7 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h
// the connection pool does not get exhausted.

var tokenString string
var expiresAt int64
var tokenClaims jwt.Claims
var newTokenResponse *AccessTokenResponse

err = db.Transaction(func(tx *storage.Connection) error {
Expand Down Expand Up @@ -216,16 +218,16 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h

issuedToken = newToken
}


tokenString, expiresAt, terr = a.generateAccessToken(ctx, tx, user, issuedToken.SessionId, models.TokenRefresh)
tokenString, tokenClaims, terr = a.generateAccessToken(ctx, tx, user, issuedToken.SessionId, models.TokenRefresh)
if terr != nil {
httpErr, ok := terr.(*HTTPError)
if ok {
return httpErr
}
return internalServerError("error generating jwt token").WithInternalError(terr)
}

refreshedAt := a.Now()
session.RefreshedAt = &refreshedAt

Expand All @@ -247,23 +249,18 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h
return internalServerError("failed to update session information").WithInternalError(terr)
}

newTokenResponse = &AccessTokenResponse{
Token: tokenString,
TokenType: "bearer",
ExpiresIn: config.JWT.Exp,
ExpiresAt: expiresAt,
RefreshToken: issuedToken.Token,
User: user,
}
newTokenResponse = a.constructAccessTokenResponse(user, tokenClaims, tokenString, issuedToken.Token)

if terr = a.setCookieTokens(config, newTokenResponse, false, w); terr != nil {
return internalServerError("Failed to set JWT cookie. %s", terr)
}

return nil
})
if err == nil {
// success

metering.RecordLogin("token", user.ID)

return sendJSON(w, http.StatusOK, newTokenResponse)
}

Expand Down
76 changes: 75 additions & 1 deletion internal/api/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,7 @@ func (ts *TokenTestSuite) TestTokenRefreshWithUnexpiredSession() {

w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)

assert.Equal(ts.T(), http.StatusOK, w.Code)
}

Expand Down Expand Up @@ -762,6 +763,79 @@ end; $$ language plpgsql;`,
}
}


func (ts *TokenTestSuite) TestCustomAccessTokenInResponseUser() {
type customAccessTokenTestcase struct {
desc string
grantType string
requestBody map[string]interface{}
}
var hookFunctionSQL string = ` create or replace function custom_access_token_add_claim(input jsonb)
returns jsonb
language plpgsql
as $$
declare
result jsonb;
begin
input := jsonb_set(input, '{claims,user_metadata,new_metadata}', '"newvalue"', true);
result := jsonb_build_object('claims', input->'claims');
return result;
end;
$$;`
cases := []customAccessTokenTestcase{
{
desc: "check user is updated in refresh token grant type",
grantType: "refresh_token",
requestBody: map[string]interface{}{
"refresh_token": ts.RefreshToken.Token,
},

}, {
desc: "check user is updated in password grant type",
grantType: "password",
requestBody: map[string]interface{}{
"email": "test@example.com",
"password": "password",
},

},
}
for _, c := range cases {
ts.T().Run(c.desc, func(t *testing.T) {
ts.Config.Hook.CustomAccessToken.Enabled = true
ts.Config.Hook.CustomAccessToken.URI = "pg-functions://postgres/auth/custom_access_token_add_claim"
require.NoError(t, ts.Config.Hook.CustomAccessToken.PopulateExtensibilityPoint())

err := ts.API.db.RawQuery(hookFunctionSQL).Exec()
require.NoError(t, err)

var buffer bytes.Buffer
require.NoError(t, json.NewEncoder(&buffer).Encode(c.requestBody))

url := "http://localhost/token?grant_type=" + c.grantType
req := httptest.NewRequest(http.MethodPost, url, &buffer)
req.Header.Set("Content-Type", "application/json")

w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)

var userResponse struct {
User *models.User `json:"user"`
}

require.NoError(t, json.NewDecoder(w.Result().Body).Decode(&userResponse))

value, exists := userResponse.User.UserMetaData["new_metadata"]
require.True(t, exists, "Key 'new_metadata' does not exist in UserMetadata")
require.Equal(t, "newvalue", value, "The value of 'new_metadata' is not 'newvalue'")

cleanupHookSQL := fmt.Sprintf("drop function if exists %s", ts.Config.Hook.CustomAccessToken.HookName)
require.NoError(t, ts.API.db.RawQuery(cleanupHookSQL).Exec())
ts.Config.Hook.CustomAccessToken.Enabled = false
})
}
}

func (ts *TokenTestSuite) TestAllowSelectAuthenticationMethods() {

companyUser, err := models.NewUser("12345678", "test@company.com", "password", ts.Config.JWT.Aud, nil)
Expand Down Expand Up @@ -855,4 +929,4 @@ $$;`
ts.Config.Hook.CustomAccessToken.Enabled = false
})
}
}
}