Skip to content

Commit

Permalink
fix: patch type errors
Browse files Browse the repository at this point in the history
  • Loading branch information
J0 committed Sep 25, 2024
1 parent ded02fb commit 235dc06
Showing 1 changed file with 35 additions and 28 deletions.
63 changes: 35 additions & 28 deletions internal/crypto/password.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ var ErrScryptMismatchedHashAndPassword = errors.New("crypto: scrypt hash and pas

// argon2HashRegexp https://github.com/P-H-C/phc-string-format/blob/master/phc-sf-spec.md#argon2-encoding
var argon2HashRegexp = regexp.MustCompile("^[$](?P<alg>argon2(d|i|id))[$]v=(?P<v>(16|19))[$]m=(?P<m>[0-9]+),t=(?P<t>[0-9]+),p=(?P<p>[0-9]+)(,keyid=(?P<keyid>[^,]+))?(,data=(?P<data>[^$]+))?[$](?P<salt>[^$]+)[$](?P<hash>.+)$")
var scryptHashRegexp = regexp.MustCompile(`^\$scrypt\$ln=(?P<n>[0-9]+),r=(?P<r>[0-9]+),p=(?P<p>[0-9]+)(,ss=(?P<ss>[^,]+))?(,sk=(?P<sk>[^,]+))?\$(?P<salt>[^$]+)\$(?P<hash>.+)$`)
var scryptHashRegexp = regexp.MustCompile(`^\$scrypt\$n=(?P<n>[0-9]+),r=(?P<r>[0-9]+),p=(?P<p>[0-9]+)(,ss=(?P<ss>[^,]+))?(,sk=(?P<sk>[^,]+))?\$(?P<salt>[^$]+)\$(?P<hash>.+)$`)

type Argon2HashInput struct {
alg string
Expand All @@ -72,14 +72,13 @@ type Argon2HashInput struct {
type ScryptHashInput struct {
alg string
v string
n int
r int
p int
keyLen int
salt []byte
rawHash []byte
memory uint64
rounds uint64
threads uint64
saltSeparator []byte // Optional: Salt separator used in Firebase-style scrypt
signerKey []byte // Optional: Signer key used in Firebase-style scrypt
salt []byte
rawHash []byte
}

func ParseScryptHash(hash string) (*ScryptHashInput, error) {
Expand All @@ -88,6 +87,8 @@ func ParseScryptHash(hash string) (*ScryptHashInput, error) {
return nil, errors.New("crypto: incorrect scrypt hash format")
}

alg := string(argon2HashRegexp.ExpandString(nil, "$alg", hash, submatch))
v := string(scryptHashRegexp.ExpandString(nil, "$v", hash, submatch))
n := string(scryptHashRegexp.ExpandString(nil, "$n", hash, submatch))
r := string(scryptHashRegexp.ExpandString(nil, "$r", hash, submatch))
p := string(scryptHashRegexp.ExpandString(nil, "$p", hash, submatch))
Expand All @@ -96,25 +97,33 @@ func ParseScryptHash(hash string) (*ScryptHashInput, error) {
saltB64 := string(scryptHashRegexp.ExpandString(nil, "$salt", hash, submatch))
hashB64 := string(scryptHashRegexp.ExpandString(nil, "$hash", hash, submatch))

nValue, err := strconv.Atoi(n)
if alg != "scrypt" {
return nil, fmt.Errorf("crypto: scrypt hash uses unsupported algorithm %q only scrypt supported", alg)
}
if v != "1" {
return nil, fmt.Errorf("crypto: scrypt hash uses unsupported version $q only version 1 is supported", v)
}

memory, err := strconv.ParseUint(n, 10, 32)
if err != nil {
return nil, fmt.Errorf("crypto: scrypt hash has invalid n parameter %q: %w", n, err)
return nil, fmt.Errorf("crypto: scrypt hash has invalid n parameter %q %w", memory, err)
}
if nValue <= 1 || (nValue&(nValue-1)) != 0 {

if memory <= 1 || (memory&(memory-1)) != 0 {
return nil, fmt.Errorf("crypto: scrypt hash has invalid n parameter %q: must be a power of 2 greater than 1", n)
}

rValue, err := strconv.Atoi(r)
rounds, err := strconv.ParseUint(r, 10, 64)
if err != nil {
return nil, fmt.Errorf("crypto: scrypt hash has invalid r parameter %q: %w", r, err)
}

pValue, err := strconv.Atoi(p)
threads, err := strconv.ParseUint(p, 10, 8)
if err != nil {
return nil, fmt.Errorf("crypto: scrypt hash has invalid p parameter %q: %w", p, err)
return nil, fmt.Errorf("crypto: argon2 hash has invalid p parameter %q %w", p, err)
}

if rValue*pValue >= 1<<30 {
if rounds*threads >= 1<<30 {
return nil, fmt.Errorf("crypto: scrypt hash has invalid r and p parameters: r * p must be < 2^30")
}

Expand Down Expand Up @@ -143,12 +152,11 @@ func ParseScryptHash(hash string) (*ScryptHashInput, error) {
}

input := &ScryptHashInput{
alg: "scrypt",
v: "1",
n: nValue,
r: rValue,
p: pValue,
keyLen: len(rawHash),
alg: alg,
v: v,
memory: memory,
rounds: rounds,
threads: threads,
salt: salt,
rawHash: rawHash,
saltSeparator: saltSeparator,
Expand All @@ -160,7 +168,6 @@ func ParseScryptHash(hash string) (*ScryptHashInput, error) {

func ParseArgon2Hash(hash string) (*Argon2HashInput, error) {
submatch := argon2HashRegexp.FindStringSubmatchIndex(hash)

if submatch == nil {
return nil, errors.New("crypto: incorrect argon2 hash format")
}
Expand Down Expand Up @@ -274,9 +281,9 @@ func compareHashAndPasswordScrypt(ctx context.Context, hash, password string) er
attributes := []attribute.KeyValue{
attribute.String("alg", input.alg),
attribute.String("v", input.v),
attribute.Int64("n", int64(input.n)),
attribute.Int64("r", int64(input.r)),
attribute.Int("p", int(input.p)),
attribute.Int64("n", int64(input.memory)),
attribute.Int64("r", int64(input.rounds)),
attribute.Int("p", int(input.threads)),
attribute.Int("len", len(input.rawHash)),
attribute.Bool("is_firebase", len(input.saltSeparator) > 0),
}
Expand All @@ -294,10 +301,10 @@ func compareHashAndPasswordScrypt(ctx context.Context, hash, password string) er
if len(input.saltSeparator) > 0 {
// Firebase-style scrypt
combinedSalt := append(input.salt, input.saltSeparator...)
derivedKey, err = firebaseScrypt([]byte(password), combinedSalt, input.signerKey, input.n, input.r, input.p, len(input.rawHash))
derivedKey, err = firebaseScrypt([]byte(password), combinedSalt, input.signerKey, input.memory, input.rounds, input.threads, len(input.rawHash))
} else {
// Standard scrypt
derivedKey, err = scrypt.Key([]byte(password), input.salt, input.n, input.r, input.p, len(input.rawHash))
derivedKey, err = scrypt.Key([]byte(password), input.salt, int(input.memory), int(input.rounds), int(input.threads), len(input.rawHash))
}
if err != nil {
return fmt.Errorf("failed to derive scrypt key: %w", err)
Expand All @@ -313,9 +320,9 @@ func compareHashAndPasswordScrypt(ctx context.Context, hash, password string) er
return nil
}

func firebaseScrypt(password, salt, signerKey []byte, N, r, p, keyLen int) ([]byte, error) {
func firebaseScrypt(password, salt, signerKey []byte, N, r, p uint64, keyLen int) ([]byte, error) {
// Step 1: Use standard scrypt to derive an intermediate key
intermediateKey, err := scrypt.Key(password, salt, N, r, p, 32)
intermediateKey, err := scrypt.Key(password, salt, int(N), int(r), int(p), 32)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 235dc06

Please sign in to comment.