From 170bceb08158d436290878b15b2ecff2251ca13c Mon Sep 17 00:00:00 2001 From: joel Date: Wed, 21 Feb 2024 16:07:24 +0800 Subject: [PATCH] feat: refactor PKCE Co-authored-by: Stojan Dimitrovski --- internal/api/external.go | 8 +++----- internal/api/magic_link.go | 17 ++++++++++------- internal/api/pkce.go | 11 +++++++++++ internal/api/recover.go | 16 ++++++++++------ internal/api/signup.go | 5 +++-- internal/api/sso.go | 8 ++------ internal/api/token_test.go | 3 +-- internal/api/user.go | 9 +++++---- internal/api/verify_test.go | 4 ++-- internal/models/flow_state.go | 18 ++---------------- 10 files changed, 49 insertions(+), 50 deletions(-) diff --git a/internal/api/external.go b/internal/api/external.go index a47d201dc..a1709978a 100644 --- a/internal/api/external.go +++ b/internal/api/external.go @@ -79,15 +79,13 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ flowType := getFlowFromChallenge(codeChallenge) flowStateID := "" - if flowType == models.PKCEFlow { + if isPKCEFlow(flowType) { codeChallengeMethodType, err := models.ParseCodeChallengeMethod(codeChallengeMethod) if err != nil { return "", err } - flowState, err := models.NewFlowState(providerType, codeChallenge, codeChallengeMethodType, models.OAuth) - if err != nil { - return "", err - } + flowState := models.NewFlowState(providerType, codeChallenge, codeChallengeMethodType, models.OAuth, nil) + if err := a.db.Create(flowState); err != nil { return "", err } diff --git a/internal/api/magic_link.go b/internal/api/magic_link.go index e1b12caaf..c879470e1 100644 --- a/internal/api/magic_link.go +++ b/internal/api/magic_link.go @@ -125,18 +125,21 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { return sendJSON(w, http.StatusOK, make(map[string]string)) } + var flowState *models.FlowState + + if isPKCEFlow(flowType) { + flowState, err = generateFlowState(flowType, models.MagicLink, params.CodeChallengeMethod, params.CodeChallenge, &user.ID) + if err != nil { + return err + } + } err = db.Transaction(func(tx *storage.Connection) error { if terr := models.NewAuditLogEntry(r, tx, user, models.UserRecoveryRequestedAction, "", nil); terr != nil { return terr } - - if isPKCEFlow(flowType) { - codeChallengeMethod, terr := models.ParseCodeChallengeMethod(params.CodeChallengeMethod) - if terr != nil { - return terr - } - if terr := models.NewFlowStateWithUserID(tx, models.MagicLink.String(), params.CodeChallenge, codeChallengeMethod, models.MagicLink, &user.ID); terr != nil { + if flowState != nil { + if terr := tx.Create(flowState); terr != nil { return terr } } diff --git a/internal/api/pkce.go b/internal/api/pkce.go index a186aa464..48f3f6606 100644 --- a/internal/api/pkce.go +++ b/internal/api/pkce.go @@ -4,6 +4,7 @@ import ( "regexp" "time" + "github.com/gofrs/uuid" "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/storage" ) @@ -78,3 +79,13 @@ func getFlowFromChallenge(codeChallenge string) models.FlowType { return models.ImplicitFlow } } + +func generateFlowState(flowType models.FlowType, authenticationMethod models.AuthenticationMethod, codeChallengeMethodParam string, codeChallenge string, userID *uuid.UUID) (*models.FlowState, error) { + codeChallengeMethod, err := models.ParseCodeChallengeMethod(codeChallengeMethodParam) + if err != nil { + return nil, err + } + flowState := models.NewFlowState(authenticationMethod.String(), codeChallenge, codeChallengeMethod, authenticationMethod, userID) + return flowState, nil + +} diff --git a/internal/api/recover.go b/internal/api/recover.go index 9a5757565..991758bfd 100644 --- a/internal/api/recover.go +++ b/internal/api/recover.go @@ -62,6 +62,13 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error { } return internalServerError("Unable to process request").WithInternalError(err) } + var flowState *models.FlowState + if isPKCEFlow(flowType) { + flowState, err = generateFlowState(flowType, models.Recovery, params.CodeChallengeMethod, params.CodeChallenge, &(user.ID)) + if err != nil { + return err + } + } err = db.Transaction(func(tx *storage.Connection) error { if terr := models.NewAuditLogEntry(r, tx, user, models.UserRecoveryRequestedAction, "", nil); terr != nil { @@ -69,15 +76,12 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error { } mailer := a.Mailer(ctx) referrer := utilities.GetReferrer(r, config) - if isPKCEFlow(flowType) { - codeChallengeMethod, terr := models.ParseCodeChallengeMethod(params.CodeChallengeMethod) - if terr != nil { - return terr - } - if terr := models.NewFlowStateWithUserID(tx, models.Recovery.String(), params.CodeChallenge, codeChallengeMethod, models.Recovery, &(user.ID)); terr != nil { + if flowState != nil { + if terr := tx.Create(flowState); terr != nil { return terr } } + externalURL := getExternalHost(ctx) return a.sendPasswordRecovery(tx, user, mailer, config.SMTP.MaxFrequency, referrer, externalURL, config.Mailer.OtpLength, flowType) }) diff --git a/internal/api/signup.go b/internal/api/signup.go index d7f289901..547f7590d 100644 --- a/internal/api/signup.go +++ b/internal/api/signup.go @@ -227,8 +227,9 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { }); terr != nil { return terr } - if ok := isPKCEFlow(flowType); ok { - if terr := models.NewFlowStateWithUserID(tx, params.Provider, params.CodeChallenge, codeChallengeMethod, models.EmailSignup, &user.ID); terr != nil { + if isPKCEFlow(flowType) { + flowState := models.NewFlowState(params.Provider, params.CodeChallenge, codeChallengeMethod, models.EmailSignup, &user.ID) + if terr := tx.Create(flowState); terr != nil { return terr } } diff --git a/internal/api/sso.go b/internal/api/sso.go index d93ff82dc..e8208ee5a 100644 --- a/internal/api/sso.go +++ b/internal/api/sso.go @@ -66,12 +66,8 @@ func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error { flowType := getFlowFromChallenge(params.CodeChallenge) var flowStateID *uuid.UUID flowStateID = nil - if flowType == models.PKCEFlow { - codeChallengeMethodType, err := models.ParseCodeChallengeMethod(codeChallengeMethod) - if err != nil { - return err - } - flowState, err := models.NewFlowState(models.SSOSAML.String(), codeChallenge, codeChallengeMethodType, models.SSOSAML) + if isPKCEFlow(flowType) { + flowState, err := generateFlowState(flowType, models.SSOSAML, codeChallengeMethod, codeChallenge, nil) if err != nil { return err } diff --git a/internal/api/token_test.go b/internal/api/token_test.go index 0c8cc2377..b12a79a8e 100644 --- a/internal/api/token_test.go +++ b/internal/api/token_test.go @@ -306,8 +306,7 @@ func (ts *TokenTestSuite) TestTokenPKCEGrantFailure() { invalidVerifier := codeVerifier + "123" codeChallenge := sha256.Sum256([]byte(codeVerifier)) challenge := base64.RawURLEncoding.EncodeToString(codeChallenge[:]) - flowState, err := models.NewFlowState("github", challenge, models.SHA256, models.OAuth) - require.NoError(ts.T(), err) + flowState := models.NewFlowState("github", challenge, models.SHA256, models.OAuth, nil) flowState.AuthCode = authCode require.NoError(ts.T(), ts.API.db.Create(flowState)) cases := []struct { diff --git a/internal/api/user.go b/internal/api/user.go index 73991b464..20d69c823 100644 --- a/internal/api/user.go +++ b/internal/api/user.go @@ -196,11 +196,12 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { referrer := utilities.GetReferrer(r, config) flowType := getFlowFromChallenge(params.CodeChallenge) if isPKCEFlow(flowType) { - codeChallengeMethod, terr := models.ParseCodeChallengeMethod(params.CodeChallengeMethod) - if terr != nil { - return terr + flowState, err := generateFlowState(flowType, models.EmailChange, params.CodeChallengeMethod, params.CodeChallenge, &user.ID) + if err != nil { + return err } - if terr := models.NewFlowStateWithUserID(tx, models.EmailChange.String(), params.CodeChallenge, codeChallengeMethod, models.EmailChange, &user.ID); terr != nil { + + if terr := tx.Create(flowState); terr != nil { return terr } } diff --git a/internal/api/verify_test.go b/internal/api/verify_test.go index 1cdb43ba9..0a44b19e6 100644 --- a/internal/api/verify_test.go +++ b/internal/api/verify_test.go @@ -661,8 +661,8 @@ func (ts *VerifyTestSuite) TestVerifyPKCEOTP() { var buffer bytes.Buffer require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.payload)) codeChallenge := "codechallengecodechallengcodechallengcodechallengcodechallenge" + c.payload.Type - err := models.NewFlowStateWithUserID(ts.API.db, c.authenticationMethod.String(), codeChallenge, models.SHA256, c.authenticationMethod, &u.ID) - require.NoError(ts.T(), err) + flowState := models.NewFlowState(c.authenticationMethod.String(), codeChallenge, models.SHA256, c.authenticationMethod, &u.ID) + require.NoError(ts.T(), ts.API.db.Create(flowState)) requestUrl := fmt.Sprintf("http://localhost/verify?type=%v&token=%v", c.payload.Type, c.payload.Token) req := httptest.NewRequest(http.MethodGet, requestUrl, &buffer) diff --git a/internal/models/flow_state.go b/internal/models/flow_state.go index 6aced0b59..04e880ec2 100644 --- a/internal/models/flow_state.go +++ b/internal/models/flow_state.go @@ -81,21 +81,7 @@ func (FlowState) TableName() string { return tableName } -func NewFlowState(providerType, codeChallenge string, codeChallengeMethod CodeChallengeMethod, authenticationMethod AuthenticationMethod) (*FlowState, error) { - id := uuid.Must(uuid.NewV4()) - authCode := uuid.Must(uuid.NewV4()) - flowState := &FlowState{ - ID: id, - ProviderType: providerType, - CodeChallenge: codeChallenge, - CodeChallengeMethod: codeChallengeMethod.String(), - AuthCode: authCode.String(), - AuthenticationMethod: authenticationMethod.String(), - } - return flowState, nil -} - -func NewFlowStateWithUserID(tx *storage.Connection, providerType, codeChallenge string, codeChallengeMethod CodeChallengeMethod, authenticationMethod AuthenticationMethod, userID *uuid.UUID) error { +func NewFlowState(providerType, codeChallenge string, codeChallengeMethod CodeChallengeMethod, authenticationMethod AuthenticationMethod, userID *uuid.UUID) *FlowState { id := uuid.Must(uuid.NewV4()) authCode := uuid.Must(uuid.NewV4()) flowState := &FlowState{ @@ -107,7 +93,7 @@ func NewFlowStateWithUserID(tx *storage.Connection, providerType, codeChallenge AuthenticationMethod: authenticationMethod.String(), UserID: userID, } - return tx.Create(flowState) + return flowState } func FindFlowStateByAuthCode(tx *storage.Connection, authCode string) (*FlowState, error) {