Skip to content

Commit

Permalink
fix: change back to use generateFlowState
Browse files Browse the repository at this point in the history
  • Loading branch information
joel authored and joel committed Mar 4, 2024
1 parent 8827b71 commit e98eb11
Show file tree
Hide file tree
Showing 8 changed files with 37 additions and 30 deletions.
3 changes: 2 additions & 1 deletion internal/api/external.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,12 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ

flowStateID := ""
if isPKCEFlow(flowType) {
codeChallengeMethodType, err := models.MapCodeChallengeMethod(codeChallengeMethod)
codeChallengeMethodType, err := models.ParseCodeChallengeMethod(codeChallengeMethod)
if err != nil {
return "", err
}
flowState := models.NewFlowState(providerType, codeChallenge, codeChallengeMethodType, models.OAuth, nil)

if err := a.db.Create(flowState); err != nil {
return "", err
}
Expand Down
10 changes: 7 additions & 3 deletions internal/api/magic_link.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,13 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error {

return sendJSON(w, http.StatusOK, make(map[string]string))
}
flowState, err := generateFlowStateIfPKCE(flowType, models.MagicLink, params.CodeChallengeMethod, params.CodeChallenge, &user.ID)
if err != nil {
return err
var flowState *models.FlowState

if isPKCEFlow(flowType) {
flowState, err = generateFlowState(flowType, models.MagicLink, params.CodeChallengeMethod, params.CodeChallenge, &user.ID)
if err != nil {
return err
}
}

err = db.Transaction(func(tx *storage.Connection) error {
Expand Down
7 changes: 2 additions & 5 deletions internal/api/pkce.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,8 @@ func getFlowFromChallenge(codeChallenge string) models.FlowType {
}
}

func generateFlowStateIfPKCE(flowType models.FlowType, authenticationMethod models.AuthenticationMethod, codeChallengeMethodParam string, codeChallenge string, userID *uuid.UUID) (*models.FlowState, error) {
if !isPKCEFlow(flowType) {
return nil, nil
}
codeChallengeMethod, err := models.MapCodeChallengeMethod(codeChallengeMethodParam)
func generateFlowState(flowType models.FlowType, authenticationMethod models.AuthenticationMethod, codeChallengeMethodParam string, codeChallenge string, userID *uuid.UUID) (*models.FlowState, error) {
codeChallengeMethod, err := models.ParseCodeChallengeMethod(codeChallengeMethodParam)
if err != nil {
return nil, err
}
Expand Down
10 changes: 6 additions & 4 deletions internal/api/recover.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,12 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error {
}
return internalServerError("Unable to process request").WithInternalError(err)
}

flowState, err := generateFlowStateIfPKCE(flowType, models.Recovery, params.CodeChallengeMethod, params.CodeChallenge, &(user.ID))
if err != nil {
return err
var flowState *models.FlowState
if isPKCEFlow(flowType) {
flowState, err = generateFlowState(flowType, models.Recovery, params.CodeChallengeMethod, params.CodeChallenge, &(user.ID))
if err != nil {
return err
}
}

err = db.Transaction(func(tx *storage.Connection) error {
Expand Down
2 changes: 1 addition & 1 deletion internal/api/signup.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error {
flowType := getFlowFromChallenge(params.CodeChallenge)

if isPKCEFlow(flowType) {
if codeChallengeMethod, err = models.MapCodeChallengeMethod(params.CodeChallengeMethod); err != nil {
if codeChallengeMethod, err = models.ParseCodeChallengeMethod(params.CodeChallengeMethod); err != nil {
return err
}
}
Expand Down
17 changes: 9 additions & 8 deletions internal/api/sso.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,17 @@ func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error {
flowType := getFlowFromChallenge(params.CodeChallenge)
var flowStateID *uuid.UUID
flowStateID = nil
flowState, err := generateFlowStateIfPKCE(flowType, models.SSOSAML, codeChallengeMethod, codeChallenge, nil)
if err != nil {
return err
}

if flowState != nil {
if err := a.db.Create(flowState); err != nil {
if isPKCEFlow(flowType) {
flowState, err := generateFlowState(flowType, models.SSOSAML, codeChallengeMethod, codeChallenge, nil)
if err != nil {
return err
}
flowStateID = &flowState.ID
if flowState != nil {
if err := a.db.Create(flowState); err != nil {
return err
}
flowStateID = &flowState.ID
}
}

var ssoProvider *models.SSOProvider
Expand Down
16 changes: 9 additions & 7 deletions internal/api/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,14 +195,16 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error {
mailer := a.Mailer(ctx)
referrer := utilities.GetReferrer(r, config)
flowType := getFlowFromChallenge(params.CodeChallenge)
flowState, err := generateFlowStateIfPKCE(flowType, models.EmailChange, params.CodeChallengeMethod, params.CodeChallenge, &user.ID)
if err != nil {
return err
}
if isPKCEFlow(flowType) {
flowState, err := generateFlowState(flowType, models.EmailChange, params.CodeChallengeMethod, params.CodeChallenge, &user.ID)
if err != nil {
return err
}

if flowState != nil {
if terr := tx.Create(flowState); terr != nil {
return terr
if flowState != nil {
if terr := tx.Create(flowState); terr != nil {
return terr
}
}
}
externalURL := getExternalHost(ctx)
Expand Down
2 changes: 1 addition & 1 deletion internal/models/flow_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (codeChallengeMethod CodeChallengeMethod) String() string {
return ""
}

func MapCodeChallengeMethod(codeChallengeMethod string) (CodeChallengeMethod, error) {
func ParseCodeChallengeMethod(codeChallengeMethod string) (CodeChallengeMethod, error) {
switch strings.ToLower(codeChallengeMethod) {
case "s256":
return SHA256, nil
Expand Down

0 comments on commit e98eb11

Please sign in to comment.