From bcd548f7809ad59a3e22461d6eac224833b96623 Mon Sep 17 00:00:00 2001 From: donutnomad Date: Fri, 1 Nov 2024 15:41:32 +0800 Subject: [PATCH] fix: fix bug --- xasn1/mod.go | 172 ++++++++++++++++++++++++++--- xasn1/mod_test.go | 27 +++++ xed25519/pubkey.go | 3 +- xsecp256k1/kit.go | 5 +- xx509/publicKey.go | 264 +++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 455 insertions(+), 16 deletions(-) create mode 100644 xasn1/mod_test.go create mode 100644 xx509/publicKey.go diff --git a/xasn1/mod.go b/xasn1/mod.go index 5d6846d..4a35c9b 100644 --- a/xasn1/mod.go +++ b/xasn1/mod.go @@ -4,12 +4,20 @@ import ( "crypto/x509/pkix" "encoding/asn1" "errors" + "github.com/samber/lo" "golang.org/x/crypto/cryptobyte" asn11 "golang.org/x/crypto/cryptobyte/asn1" "math" "math/big" ) +// A StructuralError suggests that the ASN.1 data is valid, but the Go type +// which is receiving it doesn't match. +type StructuralError = asn1.StructuralError + +// A SyntaxError suggests that the ASN.1 data is invalid. +type SyntaxError = asn1.SyntaxError + // ParseBase128Int parses a base-128 encoded int from the given offset in the // given byte slice. It returns the value and the new offset. func ParseBase128Int(bytes []byte, initOffset int) (ret, offset int, err error) { @@ -19,7 +27,7 @@ func ParseBase128Int(bytes []byte, initOffset int) (ret, offset int, err error) // 5 * 7 bits per byte == 35 bits of data // Thus the representation is either non-minimal or too large for an int32 if shifted == 5 { - panic("base 128 integer too large") + err = StructuralError{Msg: "base 128 integer too large"} return } ret64 <<= 7 @@ -27,7 +35,7 @@ func ParseBase128Int(bytes []byte, initOffset int) (ret, offset int, err error) // integers should be minimally encoded, so the leading octet should // never be 0x80 if shifted == 0 && b == 0x80 { - panic("\"integer is not minimally encoded\"") + err = SyntaxError{Msg: "integer is not minimally encoded"} return } ret64 |= int64(b & 0x7f) @@ -36,12 +44,96 @@ func ParseBase128Int(bytes []byte, initOffset int) (ret, offset int, err error) ret = int(ret64) // Ensure that the returned value fits in an int on all platforms if ret64 > math.MaxInt32 { - panic("base 128 integer too large") + err = StructuralError{Msg: "base 128 integer too large"} + } + return + } + } + err = SyntaxError{Msg: "truncated base 128 integer"} + return +} + +type TagAndLength struct { + Class, Tag, Length int + IsCompound bool +} + +// ParseTagAndLength parses an ASN.1 tag and length pair from the given offset +// into a byte slice. It returns the parsed data and the new offset. SET and +// SET OF (tag 17) are mapped to SEQUENCE and SEQUENCE OF (tag 16) since we +// don't distinguish between ordered and unordered objects in this code. +func ParseTagAndLength(bytes []byte, initOffset int) (ret TagAndLength, offset int, err error) { + offset = initOffset + // parseTagAndLength should not be called without at least a single + // byte to read. Thus this check is for robustness: + if offset >= len(bytes) { + err = errors.New("asn1: internal error in parseTagAndLength") + return + } + b := bytes[offset] + offset++ + ret.Class = int(b >> 6) + ret.IsCompound = b&0x20 == 0x20 + ret.Tag = int(b & 0x1f) + + // If the bottom five bits are set, then the tag number is actually base 128 + // encoded afterwards + if ret.Tag == 0x1f { + ret.Tag, offset, err = ParseBase128Int(bytes, offset) + if err != nil { + return + } + // Tags should be encoded in minimal form. + if ret.Tag < 0x1f { + err = SyntaxError{Msg: "non-minimal tag"} + return + } + } + if offset >= len(bytes) { + err = SyntaxError{Msg: "truncated tag or length"} + return + } + b = bytes[offset] + offset++ + if b&0x80 == 0 { + // The length is encoded in the bottom 7 bits. + ret.Length = int(b & 0x7f) + } else { + // Bottom 7 bits give the number of length bytes to follow. + numBytes := int(b & 0x7f) + if numBytes == 0 { + err = SyntaxError{Msg: "indefinite length found (not DER)"} + return + } + ret.Length = 0 + for i := 0; i < numBytes; i++ { + if offset >= len(bytes) { + err = SyntaxError{Msg: "truncated tag or length"} + return + } + b = bytes[offset] + offset++ + if ret.Length >= 1<<23 { + // We can't shift ret.length up without + // overflowing. + err = StructuralError{Msg: "length too large"} + return } + ret.Length <<= 8 + ret.Length |= int(b) + if ret.Length == 0 { + // DER requires that lengths be minimal. + err = StructuralError{Msg: "superfluous leading zeros in length"} + return + } + } + // Short lengths must be encoded in short form. + if ret.Length < 0x80 { + err = StructuralError{Msg: "non-minimal length"} return } } - panic("base 128 integer too large") + return } @@ -83,24 +175,55 @@ func ParseObjectIdentifier(bytes []byte) (s []int, err error) { return } -type publicKeyInfo struct { +func FindTypeList(bs []byte, offset int, target int) [][]byte { + var out [][]byte + for i := offset; i < len(bs); { + ret, ni, err := ParseTagAndLength(bs, i) + if err != nil { + return nil + } + i = ni + if ret.Tag == target { + out = append(out, bs[i:i+ret.Length]) + } else if ret.Tag == asn1.TagSequence { + if ret := FindTypeList(bs, i, target); len(ret) > 0 { + out = append(out, ret...) + } + } + i += ret.Length + } + return out +} + +func FindOids(bs []byte) []asn1.ObjectIdentifier { + var oids = FindTypeList(bs, 0, asn1.TagOID) + return lo.FilterMap(oids, func(item []byte, index int) (asn1.ObjectIdentifier, bool) { + identifier, err := ParseObjectIdentifier(item) + if err != nil { + return nil, false + } + return identifier, true + }) +} + +type PublicKeyInfo struct { Raw asn1.RawContent Algorithm pkix.AlgorithmIdentifier PublicKey asn1.BitString } // ParsePKIXPublicKey x509.ParsePKIXPublicKey -func ParsePKIXPublicKey(bs []byte) ([]byte, error) { - var pki publicKeyInfo +func ParsePKIXPublicKey(bs []byte) (*PublicKeyInfo, error) { + var pki PublicKeyInfo if rest, err := asn1.Unmarshal(bs, &pki); err != nil { return nil, errors.New("x509: failed to parse public key (use ParsePKCS1PublicKey instead for this key format)") } else if len(rest) != 0 { return nil, errors.New("x509: trailing data after ASN.1 of public-key") } - return pki.PublicKey.Bytes, nil + return &pki, nil } -func ParseSignatureRS(bs []byte) (r *big.Int, s *big.Int, _ error) { +func ParseSignatureRS(bs []byte) (r []byte, s []byte, _ error) { var inner cryptobyte.String input := cryptobyte.String(bs) if !input.ReadASN1(&inner, asn11.SEQUENCE) || @@ -110,15 +233,38 @@ func ParseSignatureRS(bs []byte) (r *big.Int, s *big.Int, _ error) { !inner.Empty() { return nil, nil, errors.New("invalid ASN.1") } - return r, s, nil + return padSliceLeft(r, 32), padSliceLeft(s, 32), nil } -func ParseSignatureRSSlice(bs []byte) (out [64]byte, _ error) { +func ParseSignatureRSSlice(bs []byte) ([64]byte, error) { r, s, err := ParseSignatureRS(bs) if err != nil { return [64]byte{}, err } - r.FillBytes(out[0:32]) - s.FillBytes(out[32:64]) + var out [64]byte + copy(out[0:32], r) + copy(out[32:64], s) return out, nil } + +func MarshalAsn1SignatureRS(r, s []byte) []byte { + var b cryptobyte.Builder + b.AddASN1(asn11.SEQUENCE, func(child *cryptobyte.Builder) { + child.AddASN1BigInt(new(big.Int).SetBytes(r)) + child.AddASN1BigInt(new(big.Int).SetBytes(s)) + }) + return b.BytesOrPanic() +} + +func MarshalAsn1SignatureSlice(bs [64]byte) []byte { + return MarshalAsn1SignatureRS(bs[0:32], bs[32:64]) +} + +func padSliceLeft(bs []byte, size int) []byte { + if len(bs) >= size { + return bs[:size] + } + var out = make([]byte, size) + copy(out[size-len(bs):], bs[:]) + return out +} diff --git a/xasn1/mod_test.go b/xasn1/mod_test.go new file mode 100644 index 0000000..84f8f1f --- /dev/null +++ b/xasn1/mod_test.go @@ -0,0 +1,27 @@ +package xasn1 + +import ( + "bytes" + "crypto/rand" + "github.com/samber/lo" + "testing" +) + +func TestMarshalSigRS(t *testing.T) { + var r = mustRand(31) + var s = mustRand(31) + var asn1Sig = MarshalAsn1SignatureRS(r, s) + var rr, ss, err = ParseSignatureRS(asn1Sig) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(padSliceLeft(r, 32), rr) || !bytes.Equal(padSliceLeft(s, 32), ss) { + t.Fatalf("marshal failed") + } +} + +func mustRand(size int) []byte { + var out = make([]byte, size) + lo.Must1(rand.Reader.Read(out)) + return out +} diff --git a/xed25519/pubkey.go b/xed25519/pubkey.go index 1918bd0..b379e6a 100644 --- a/xed25519/pubkey.go +++ b/xed25519/pubkey.go @@ -39,10 +39,11 @@ func ParsePubKey(serialized [32]byte) (key PublicKey, err error) { } func ParsePubKeyASN1(bs []byte) (key PublicKey, err error) { - serialized, err := xasn1.ParsePKIXPublicKey(bs) + k, err := xasn1.ParsePKIXPublicKey(bs) if err != nil { return PublicKey{}, err } + serialized := k.PublicKey.Bytes if len(serialized) != 32 { return PublicKey{}, BadFormatPublicKeyErr } diff --git a/xsecp256k1/kit.go b/xsecp256k1/kit.go index aa623cf..55a0f2d 100644 --- a/xsecp256k1/kit.go +++ b/xsecp256k1/kit.go @@ -23,10 +23,11 @@ func (k *secp256k1Kit) SignANS1(asn1Key []byte, hash []byte) (SignatureCompat, e if err != nil { return SignatureCompat{}, err } - if len(key) != 32 { + bs := key.PublicKey.Bytes + if len(bs) != 32 { return SignatureCompat{}, BadFormatPublicKeyErr } - return k.Sign([32]byte(key), hash), nil + return k.Sign([32]byte(bs), hash), nil } func (k *secp256k1Kit) VerifySignatureRS(pubKey []byte, rBig, sBig *big.Int, hash []byte) bool { diff --git a/xx509/publicKey.go b/xx509/publicKey.go new file mode 100644 index 0000000..4e37de0 --- /dev/null +++ b/xx509/publicKey.go @@ -0,0 +1,264 @@ +package xx509 + +import ( + "bytes" + "crypto/dsa" + "crypto/ecdh" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "errors" + "github.com/decred/dcrd/dcrec/secp256k1/v4" + "golang.org/x/crypto/cryptobyte" + cryptobyte_asn1 "golang.org/x/crypto/cryptobyte/asn1" + "math/big" +) + +// RFC 5480, 2.1.1.1. Named Curve +// +// secp224r1 OBJECT IDENTIFIER ::= { +// iso(1) identified-organization(3) certicom(132) curve(0) 33 } +// +// secp256r1 OBJECT IDENTIFIER ::= { +// iso(1) member-body(2) us(840) ansi-X9-62(10045) curves(3) +// prime(1) 7 } +// +// secp384r1 OBJECT IDENTIFIER ::= { +// iso(1) identified-organization(3) certicom(132) curve(0) 34 } +// +// secp521r1 OBJECT IDENTIFIER ::= { +// iso(1) identified-organization(3) certicom(132) curve(0) 35 } +// +// NB: secp256r1 is equivalent to prime256v1 +var ( + OidNamedCurveP224 = asn1.ObjectIdentifier{1, 3, 132, 0, 33} + OidNamedCurveP256 = asn1.ObjectIdentifier{1, 2, 840, 10045, 3, 1, 7} + OidNamedCurveP384 = asn1.ObjectIdentifier{1, 3, 132, 0, 34} + OidNamedCurveP521 = asn1.ObjectIdentifier{1, 3, 132, 0, 35} + OidNameCurveSecp256k1 = asn1.ObjectIdentifier{1, 3, 132, 0, 10} +) + +func NamedECurveFromOID(oid asn1.ObjectIdentifier) elliptic.Curve { + switch { + case oid.Equal(OidNamedCurveP224): + return elliptic.P224() + case oid.Equal(OidNamedCurveP256): + return elliptic.P256() + case oid.Equal(OidNamedCurveP384): + return elliptic.P384() + case oid.Equal(OidNamedCurveP521): + return elliptic.P521() + case oid.Equal(OidNameCurveSecp256k1): + return secp256k1.S256() + } + return nil +} + +var ( + // OidPublicKeyRSA RFC 3279, 2.3 Public Key Algorithms + // + // pkcs-1 OBJECT IDENTIFIER ::== { iso(1) member-body(2) us(840) + // rsadsi(113549) pkcs(1) 1 } + // + // rsaEncryption OBJECT IDENTIFIER ::== { pkcs1-1 1 } + // + // id-dsa OBJECT IDENTIFIER ::== { iso(1) member-body(2) us(840) + // x9-57(10040) x9cm(4) 1 } + OidPublicKeyRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 1} + OidPublicKeyDSA = asn1.ObjectIdentifier{1, 2, 840, 10040, 4, 1} + // OidPublicKeyECDSA RFC 5480, 2.1.1 Unrestricted Algorithm Identifier and Parameters + // + // id-ecPublicKey OBJECT IDENTIFIER ::= { + // iso(1) member-body(2) us(840) ansi-X9-62(10045) keyType(2) 1 } + OidPublicKeyECDSA = asn1.ObjectIdentifier{1, 2, 840, 10045, 2, 1} + // OidPublicKeyX25519 RFC 8410, Section 3 + // + // id-X25519 OBJECT IDENTIFIER ::= { 1 3 101 110 } + // id-Ed25519 OBJECT IDENTIFIER ::= { 1 3 101 112 } + OidPublicKeyX25519 = asn1.ObjectIdentifier{1, 3, 101, 110} + OidPublicKeyEd25519 = asn1.ObjectIdentifier{1, 3, 101, 112} +) + +// GetPublicKeyAlgorithmFromOID returns the exposed PublicKeyAlgorithm +// identifier for public key types supported in certificates and CSRs. Marshal +// and Parse functions may support a different set of public key types. +func GetPublicKeyAlgorithmFromOID(oid asn1.ObjectIdentifier) x509.PublicKeyAlgorithm { + switch { + case oid.Equal(OidPublicKeyRSA): + return x509.RSA + case oid.Equal(OidPublicKeyDSA): + return x509.DSA + case oid.Equal(OidPublicKeyECDSA): + return x509.ECDSA + case oid.Equal(OidPublicKeyEd25519): + return x509.Ed25519 + } + return x509.UnknownPublicKeyAlgorithm +} + +// pkixPublicKey reflects a PKIX public key structure. See SubjectPublicKeyInfo +// in RFC 3280. +type pkixPublicKey struct { + Algo pkix.AlgorithmIdentifier + BitString asn1.BitString +} + +// pkcs1PublicKey reflects the ASN.1 structure of a PKCS #1 public key. +type pkcs1PublicKey struct { + N *big.Int + E int +} + +type publicKeyInfo struct { + Raw asn1.RawContent + Algorithm pkix.AlgorithmIdentifier + PublicKey asn1.BitString +} + +// ParsePKIXPublicKey parses a public key in PKIX, ASN.1 DER form. The encoded +// public key is a SubjectPublicKeyInfo structure (see RFC 5280, Section 4.1). +// +// It returns a *[rsa.PublicKey], *[dsa.PublicKey], *[ecdsa.PublicKey], +// [ed25519.PublicKey] (not a pointer), or *[ecdh.PublicKey] (for X25519). +// More types might be supported in the future. +// +// This kind of key is commonly encoded in PEM blocks of type "PUBLIC KEY". +func ParsePKIXPublicKey(derBytes []byte) (pub any, err error) { + var pki publicKeyInfo + if rest, err := asn1.Unmarshal(derBytes, &pki); err != nil { + if _, err := asn1.Unmarshal(derBytes, &pkcs1PublicKey{}); err == nil { + return nil, errors.New("x509: failed to parse public key (use ParsePKCS1PublicKey instead for this key format)") + } + return nil, err + } else if len(rest) != 0 { + return nil, errors.New("x509: trailing data after ASN.1 of public-key") + } + return parsePublicKey(&pki) +} + +func MarshalPKIXPublicKeyRaw(publicKeyBytes []byte, publicKeyAlgorithm pkix.AlgorithmIdentifier) []byte { + publicKey := pkixPublicKey{ + Algo: publicKeyAlgorithm, + BitString: asn1.BitString{ + Bytes: publicKeyBytes, + BitLength: 8 * len(publicKeyBytes), + }, + } + ret, _ := asn1.Marshal(publicKey) + return ret +} + +//asn1.ObjectIdentifier + +////// RFC 8410, Section 3 +// // // +// // // id-X25519 OBJECT IDENTIFIER ::= { 1 3 101 110 } +// // // id-Ed25519 OBJECT IDENTIFIER ::= { 1 3 101 112 } +// // oidPublicKeyX25519 = asn1.ObjectIdentifier{1, 3, 101, 110} +// // oidPublicKeyEd25519 = asn1.ObjectIdentifier{1, 3, 101, 112} + +func parsePublicKey(keyData *publicKeyInfo) (any, error) { + oid := keyData.Algorithm.Algorithm + params := keyData.Algorithm.Parameters + der := cryptobyte.String(keyData.PublicKey.RightAlign()) + switch { + case oid.Equal(OidPublicKeyRSA): + // RSA public keys must have a NULL in the parameters. + // See RFC 3279, Section 2.3.1. + if !bytes.Equal(params.FullBytes, asn1.NullBytes) { + return nil, errors.New("x509: RSA key missing NULL parameters") + } + + p := &pkcs1PublicKey{N: new(big.Int)} + if !der.ReadASN1(&der, cryptobyte_asn1.SEQUENCE) { + return nil, errors.New("x509: invalid RSA public key") + } + if !der.ReadASN1Integer(p.N) { + return nil, errors.New("x509: invalid RSA modulus") + } + if !der.ReadASN1Integer(&p.E) { + return nil, errors.New("x509: invalid RSA public exponent") + } + + if p.N.Sign() <= 0 { + return nil, errors.New("x509: RSA modulus is not a positive number") + } + if p.E <= 0 { + return nil, errors.New("x509: RSA public exponent is not a positive number") + } + + pub := &rsa.PublicKey{ + E: p.E, + N: p.N, + } + return pub, nil + case oid.Equal(OidPublicKeyECDSA): + paramsDer := cryptobyte.String(params.FullBytes) + namedCurveOID := new(asn1.ObjectIdentifier) + if !paramsDer.ReadASN1ObjectIdentifier(namedCurveOID) { + return nil, errors.New("x509: invalid ECDSA parameters") + } + namedCurve := NamedECurveFromOID(*namedCurveOID) + if namedCurve == nil { + return nil, errors.New("x509: unsupported elliptic curve") + } + x, y := elliptic.Unmarshal(namedCurve, der) + if x == nil { + return nil, errors.New("x509: failed to unmarshal elliptic curve point") + } + pub := &ecdsa.PublicKey{ + Curve: namedCurve, + X: x, + Y: y, + } + return pub, nil + case oid.Equal(OidPublicKeyEd25519): + // RFC 8410, Section 3 + // > For all of the OIDs, the parameters MUST be absent. + if len(params.FullBytes) != 0 { + return nil, errors.New("x509: Ed25519 key encoded with illegal parameters") + } + if len(der) != ed25519.PublicKeySize { + return nil, errors.New("x509: wrong Ed25519 public key size") + } + return ed25519.PublicKey(der), nil + case oid.Equal(OidPublicKeyX25519): + // RFC 8410, Section 3 + // > For all of the OIDs, the parameters MUST be absent. + if len(params.FullBytes) != 0 { + return nil, errors.New("x509: X25519 key encoded with illegal parameters") + } + return ecdh.X25519().NewPublicKey(der) + case oid.Equal(OidPublicKeyDSA): + y := new(big.Int) + if !der.ReadASN1Integer(y) { + return nil, errors.New("x509: invalid DSA public key") + } + pub := &dsa.PublicKey{ + Y: y, + Parameters: dsa.Parameters{ + P: new(big.Int), + Q: new(big.Int), + G: new(big.Int), + }, + } + paramsDer := cryptobyte.String(params.FullBytes) + if !paramsDer.ReadASN1(¶msDer, cryptobyte_asn1.SEQUENCE) || + !paramsDer.ReadASN1Integer(pub.Parameters.P) || + !paramsDer.ReadASN1Integer(pub.Parameters.Q) || + !paramsDer.ReadASN1Integer(pub.Parameters.G) { + return nil, errors.New("x509: invalid DSA parameters") + } + if pub.Y.Sign() <= 0 || pub.Parameters.P.Sign() <= 0 || + pub.Parameters.Q.Sign() <= 0 || pub.Parameters.G.Sign() <= 0 { + return nil, errors.New("x509: zero or negative DSA parameter") + } + return pub, nil + default: + return nil, errors.New("x509: unknown public key algorithm") + } +}