|
| 1 | +package kgc |
| 2 | + |
| 3 | +import ( |
| 4 | + "crypto/rand" |
| 5 | + "encoding/binary" |
| 6 | + "fmt" |
| 7 | + "math/big" |
| 8 | + "github.com/emmansun/gmsm/sm2" |
| 9 | + "github.com/emmansun/gmsm/sm3" |
| 10 | + "github.com/BurntSushi/toml" |
| 11 | + "os" |
| 12 | + "path/filepath" |
| 13 | + "github.com/OpenNHP/opennhp/kgc/user" |
| 14 | +) |
| 15 | + |
| 16 | +var( |
| 17 | + id = user.ID |
| 18 | + curve = sm2.P256() |
| 19 | + N = curve.Params().N |
| 20 | + Gx = curve.Params().Gx |
| 21 | + Gy = curve.Params().Gy |
| 22 | + IdA = []byte(id) |
| 23 | + EntlA = len(IdA) * 8 |
| 24 | + Ms *big.Int |
| 25 | + PpubX *big.Int |
| 26 | + PpubY *big.Int |
| 27 | + curveParams CurveParams |
| 28 | + A, B *big.Int |
| 29 | + WAx, WAy, W *big.Int |
| 30 | + HA []byte |
| 31 | + L *big.Int |
| 32 | + TA *big.Int |
| 33 | +) |
| 34 | + |
| 35 | +// A structure for storing configuration |
| 36 | +type CurveParams struct { |
| 37 | + A string `toml:"a"` |
| 38 | + B string `toml:"b"` |
| 39 | +} |
| 40 | + |
| 41 | +//InitConfig loads the configuration and initializes global variables |
| 42 | +func InitConfig() error { |
| 43 | + // Get the current working directory path |
| 44 | + wd, err := os.Getwd() |
| 45 | + if err != nil { |
| 46 | + return fmt.Errorf("error getting current directory: %v", err) |
| 47 | + } |
| 48 | + |
| 49 | + // Path to splice TOML files |
| 50 | + tomlFilePath := filepath.Join(wd, "kgc", "main", "etc", "Curve.toml") |
| 51 | + |
| 52 | + // Read and parse TOML files |
| 53 | + _, err = toml.DecodeFile(tomlFilePath, &curveParams) |
| 54 | + if err != nil { |
| 55 | + return fmt.Errorf("error loading TOML file: %v", err) |
| 56 | + } |
| 57 | + |
| 58 | + // Convert a and b from strings in TOML file to big.Int type |
| 59 | + A = new(big.Int) |
| 60 | + A.SetString(curveParams.A, 16) |
| 61 | + B = new(big.Int) |
| 62 | + B.SetString(curveParams.B, 16) |
| 63 | + return nil |
| 64 | +} |
| 65 | + |
| 66 | +func GetA() *big.Int { |
| 67 | + return A |
| 68 | +} |
| 69 | + |
| 70 | +func GetB() *big.Int { |
| 71 | + return B |
| 72 | +} |
| 73 | + |
| 74 | +// GenerateMasterKeyPairSM2,Generate the system's master private key ms and master public key Ppub |
| 75 | +func GenerateMasterKeyPairSM2() (*big.Int, *big.Int, error) { |
| 76 | + curve := sm2.P256() |
| 77 | + ms, err := rand.Int(rand.Reader, curve.Params().N) |
| 78 | + if err != nil { |
| 79 | + return nil, nil, fmt.Errorf("failed to generate system master private key ms: %v", err) |
| 80 | + } |
| 81 | + if ms.Cmp(big.NewInt(0)) == 0 { |
| 82 | + ms, err = rand.Int(rand.Reader, curve.Params().N) |
| 83 | + if err != nil { |
| 84 | + return nil, nil, fmt.Errorf("regeneration of system master private key ms failed: %v", err) |
| 85 | + } |
| 86 | + } |
| 87 | + Ppubx, Ppuby := curve.ScalarBaseMult(ms.Bytes()) |
| 88 | + Ms = ms |
| 89 | + PpubX = Ppubx |
| 90 | + PpubY = Ppuby |
| 91 | + return Ppubx, Ppuby, nil |
| 92 | +} |
| 93 | + |
| 94 | +// GenerateWA,Calculate WA = [w]G + UA |
| 95 | +func GenerateWA(UAx, UAy *big.Int) (*big.Int, *big.Int, *big.Int, error) { |
| 96 | + curve := sm2.P256() |
| 97 | + // Generate a random number w in the range [1, n-1] |
| 98 | + w, err := rand.Int(rand.Reader, curve.Params().N) |
| 99 | + if err != nil { |
| 100 | + return nil, nil, nil, fmt.Errorf("failed to generate random number w: %v", err) |
| 101 | + } |
| 102 | + |
| 103 | + // Make sure w is not 0 |
| 104 | + if w.Cmp(big.NewInt(0)) == 0 { |
| 105 | + w, err = rand.Int(rand.Reader, curve.Params().N) |
| 106 | + if err != nil { |
| 107 | + return nil, nil, nil, fmt.Errorf("failed to regenerate random number w: %v", err) |
| 108 | + } |
| 109 | + } |
| 110 | + Wx, Wy := curve.ScalarBaseMult(w.Bytes()) |
| 111 | + wAx, wAy := curve.Add(Wx, Wy, UAx, UAy) |
| 112 | + WAx = wAx |
| 113 | + WAy = wAy |
| 114 | + W = w |
| 115 | + return WAx, WAy, w, nil |
| 116 | +} |
| 117 | + |
| 118 | +// Calculate HA = H256(entlA || idA || a || b || xG || yG || xPub || yPub) |
| 119 | +func CalculateHA(entlA int, idA []byte, a, b, xG, yG, xPub, yPub *big.Int) ([]byte,error) { |
| 120 | + if a == nil || b == nil || xG == nil || yG == nil || xPub == nil || yPub == nil { |
| 121 | + return nil, fmt.Errorf("one or more big.Int parameters passed in were nil") |
| 122 | + } |
| 123 | + entlABytes := make([]byte, 2) |
| 124 | + binary.BigEndian.PutUint16(entlABytes, uint16(entlA)) |
| 125 | + data := append(entlABytes, idA...) |
| 126 | + data = append(data, a.Bytes()...) |
| 127 | + data = append(data, b.Bytes()...) |
| 128 | + data = append(data, xG.Bytes()...) |
| 129 | + data = append(data, yG.Bytes()...) |
| 130 | + data = append(data, xPub.Bytes()...) |
| 131 | + data = append(data, yPub.Bytes()...) |
| 132 | + hash := sm3.New() |
| 133 | + hash.Write(data) |
| 134 | + ha := hash.Sum(nil) |
| 135 | + HA = ha |
| 136 | + return HA,nil |
| 137 | +} |
| 138 | + |
| 139 | +// ComputeL l = H256(xWA‖yWA‖HA) mod n |
| 140 | +func ComputeL(xWA, yWA *big.Int, HA []byte, n *big.Int) (*big.Int, error) { |
| 141 | + xBits := intToBitString(xWA) |
| 142 | + yBits := intToBitString(yWA) |
| 143 | + hashData := append(xBits, yBits...) |
| 144 | + hashData = append(hashData, HA...) |
| 145 | + hash := sm3.Sum(hashData) |
| 146 | + l := new(big.Int).SetBytes(hash[:]) |
| 147 | + l.Mod(l, n) |
| 148 | + if l.Cmp(big.NewInt(0)) < 0 { |
| 149 | + return nil, fmt.Errorf("the calculated result l is a negative number") |
| 150 | + } |
| 151 | + k := (n.BitLen() + 7) / 8 |
| 152 | + lBytes := intToBytes(l, k) |
| 153 | + lInteger := new(big.Int).SetBytes(lBytes) |
| 154 | + L = lInteger |
| 155 | + return L, nil |
| 156 | +} |
| 157 | + |
| 158 | +// intToBitString |
| 159 | +func intToBitString(x *big.Int) []byte { |
| 160 | + bitLen := x.BitLen() |
| 161 | + byteLen := (bitLen + 7) / 8 |
| 162 | + bitString := make([]byte, byteLen) |
| 163 | + xBytes := x.Bytes() |
| 164 | + copy(bitString[byteLen-len(xBytes):], xBytes) |
| 165 | + return bitString |
| 166 | +} |
| 167 | + |
| 168 | +// intToBytes |
| 169 | +func intToBytes(x *big.Int, k int) []byte { |
| 170 | + m := make([]byte, k) |
| 171 | + xBytes := x.Bytes() |
| 172 | + copy(m[k-len(xBytes):], xBytes) |
| 173 | + return m |
| 174 | +} |
| 175 | + |
| 176 | + |
| 177 | +//Calculate tA= w + (l * ms) |
| 178 | +func ComputeTA(w, lInteger, ms, n *big.Int) *big.Int { |
| 179 | + tA := new(big.Int).Set(w) |
| 180 | + lMod := new(big.Int).Mod(lInteger, n) |
| 181 | + msMod := new(big.Int).Mod(ms, n) |
| 182 | + lMulMs := new(big.Int).Mul(lMod, msMod) |
| 183 | + lMulMs.Mod(lMulMs, n) |
| 184 | + tA.Add(tA, lMulMs) |
| 185 | + tA.Mod(tA, n) |
| 186 | + TA = tA |
| 187 | + return TA |
| 188 | +} |
| 189 | + |
| 190 | + |
| 191 | + |
| 192 | + |
| 193 | + |
| 194 | + |
0 commit comments