diff --git a/go.mod b/go.mod index b39aac6db..665e2f4c6 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,6 @@ module github.com/supabase/auth require ( github.com/Masterminds/semver/v3 v3.1.1 // indirect - github.com/aaronarduino/goqrsvg v0.0.0-20220419053939-17e843f1dd40 github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b github.com/badoux/checkmail v0.0.0-20170203135005-d0a759655d62 github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc diff --git a/go.sum b/go.sum index 4ffdf5198..0582152c4 100644 --- a/go.sum +++ b/go.sum @@ -39,8 +39,6 @@ github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAE github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk= github.com/XSAM/otelsql v0.16.0 h1:pOqeHGYCJmP5ezW0OvAGA+zzdgW/sV8nLHTxVnPgiXU= github.com/XSAM/otelsql v0.16.0/go.mod h1:DpO7NCSeqQdr23nU0yapjR3jGx2OdO/PihPRG+/PV0Y= -github.com/aaronarduino/goqrsvg v0.0.0-20220419053939-17e843f1dd40 h1:uz4N2yHL4MF8vZX+36n+tcxeUf8D/gL4aJkyouhDw4A= -github.com/aaronarduino/goqrsvg v0.0.0-20220419053939-17e843f1dd40/go.mod h1:dytw+5qs+pdi61fO/S4OmXR7AuEq/HvNCuG03KxQHT4= github.com/ajstarks/deck v0.0.0-20200831202436-30c9fc6549a9/go.mod h1:JynElWSGnm/4RlzPXRlREEwqTHAN3T56Bv2ITsFT3gY= github.com/ajstarks/deck/generate v0.0.0-20210309230005-c3f852c02e19/go.mod h1:T13YZdzov6OU0A1+RfKZiZN9ca6VeKdBdyDV+BY97Tk= github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b h1:slYM766cy2nI3BwyRiyQj/Ud48djTMtMebDqepE95rw= diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 94c422f0a..0d4016102 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -7,8 +7,8 @@ import ( "net/http" "net/url" - "github.com/aaronarduino/goqrsvg" svg "github.com/ajstarks/svgo" + "github.com/boombuler/barcode" "github.com/boombuler/barcode/qr" "github.com/gofrs/uuid" "github.com/pquerna/otp/totp" @@ -19,25 +19,23 @@ import ( "github.com/supabase/auth/internal/utilities" ) -const DefaultQRSize = 3 - type EnrollFactorParams struct { FriendlyName string `json:"friendly_name"` FactorType string `json:"factor_type"` Issuer string `json:"issuer"` } -type TOTPObject struct { +type TOTPSetup struct { QRCode string `json:"qr_code"` Secret string `json:"secret"` URI string `json:"uri"` } type EnrollFactorResponse struct { - ID uuid.UUID `json:"id"` - Type string `json:"type"` - FriendlyName string `json:"friendly_name"` - TOTP TOTPObject `json:"totp,omitempty"` + ID uuid.UUID `json:"id"` + Type string `json:"type"` + FriendlyName string `json:"friendly_name"` + TOTP TOTPSetup `json:"totp,omitempty"` } type VerifyFactorParams struct { @@ -114,25 +112,12 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { if numVerifiedFactors > 0 && !session.IsAAL2() { return forbiddenError("AAL2 required to enroll a new factor") } - - key, err := totp.Generate(totp.GenerateOpts{ - Issuer: issuer, - AccountName: user.GetEmail(), - }) + sideLength := 256 + totpSetup, err := generateQRCode(issuer, user.GetEmail(), sideLength) if err != nil { - return internalServerError(QRCodeGenerationErrorMessage).WithInternalError(err) - } - var buf bytes.Buffer - svgData := svg.New(&buf) - qrCode, _ := qr.Encode(key.String(), qr.M, qr.Auto) - qs := goqrsvg.NewQrSVG(qrCode, DefaultQRSize) - qs.StartQrSVG(svgData) - if err = qs.WriteQrSVG(svgData); err != nil { - return internalServerError(QRCodeGenerationErrorMessage).WithInternalError(err) + return err } - svgData.End() - - factor := models.NewFactor(user, params.FriendlyName, params.FactorType, models.FactorStateUnverified, key.Secret()) + factor := models.NewFactor(user, params.FriendlyName, params.FactorType, models.FactorStateUnverified, totpSetup.Secret) err = a.db.Transaction(func(tx *storage.Connection) error { if terr := tx.Create(factor); terr != nil { @@ -158,11 +143,10 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { ID: factor.ID, Type: models.TOTP, FriendlyName: factor.FriendlyName, - TOTP: TOTPObject{ - // See: https://css-tricks.com/probably-dont-base64-svg/ - QRCode: buf.String(), - Secret: factor.Secret, - URI: key.URL(), + TOTP: TOTPSetup{ + QRCode: totpSetup.QRCode, + Secret: totpSetup.Secret, + URI: totpSetup.URI, }, }) } @@ -223,9 +207,6 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { challenge, err := models.FindChallengeByChallengeID(a.db, params.ChallengeID) if err != nil { - if models.IsNotFoundError(err) { - return notFoundError(err.Error()) - } return internalServerError("Database error finding Challenge").WithInternalError(err) } @@ -291,13 +272,7 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { if terr = challenge.Verify(tx); terr != nil { return terr } - if !factor.IsVerified() { - if terr = factor.UpdateStatus(tx, models.FactorStateVerified); terr != nil { - return terr - } - } - user, terr = models.FindUserByID(tx, user.ID) - if terr != nil { + if terr = factor.UpdateStatus(tx, models.FactorStateVerified); terr != nil { return terr } token, terr = a.updateMFASessionAndClaims(r, tx, user, models.TOTPSignIn, models.GrantParams{ @@ -368,3 +343,45 @@ func (a *API) UnenrollFactor(w http.ResponseWriter, r *http.Request) error { ID: factor.ID, }) } + +func generateQRCode(issuer string, account string, sideLength int) (*TOTPSetup, error) { + key, err := totp.Generate(totp.GenerateOpts{ + Issuer: issuer, + AccountName: account, + }) + if err != nil { + return nil, err + } + + // See: https://pkg.go.dev/github.com/boombuler/barcode@v1.0.1/qr#ErrorCorrectionLevel + qrCode, err := qr.Encode(key.URL(), qr.Q, qr.Auto) + if err != nil { + return nil, err + } + + qrCode, err = barcode.Scale(qrCode, sideLength, sideLength) + if err != nil { + return nil, err + } + + // Create a buffer to hold SVG data. + var b bytes.Buffer + canvas := svg.New(&b) + + // Start SVG generation. + canvas.Start(qrCode.Bounds().Dx(), qrCode.Bounds().Dy()) + for x := 0; x < qrCode.Bounds().Dx(); x++ { + for y := 0; y < qrCode.Bounds().Dy(); y++ { + r, g, b, _ := qrCode.At(x, y).RGBA() + color := fmt.Sprintf("rgb(%d,%d,%d)", r>>8, g>>8, b>>8) + canvas.Rect(x, y, 1, 1, "fill:"+color) + } + } + canvas.End() + + return &TOTPSetup{ + QRCode: b.String(), + Secret: key.Secret(), + URI: key.URL(), + }, nil +} diff --git a/internal/models/challenge.go b/internal/models/challenge.go index 99758e63b..a60f647af 100644 --- a/internal/models/challenge.go +++ b/internal/models/challenge.go @@ -36,8 +36,9 @@ func NewChallenge(factor *Factor, ipAddress string) *Challenge { func FindChallengeByChallengeID(tx *storage.Connection, challengeID uuid.UUID) (*Challenge, error) { challenge, err := findChallenge(tx, "id = ?", challengeID) if err != nil { - return nil, ChallengeNotFoundError{} + return nil, err } + return challenge, nil } @@ -62,7 +63,7 @@ func findChallenge(tx *storage.Connection, query string, args ...interface{}) (* if errors.Cause(err) == sql.ErrNoRows { return nil, ChallengeNotFoundError{} } - return nil, errors.Wrap(err, "error finding challenge") + return nil, err } return obj, nil }