diff --git a/README.md b/README.md index 05475b5..d2ccaf3 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,7 @@ package main import ( "log" + "math/rand" "github.com/lucasmenendez/gosss" ) @@ -37,29 +38,36 @@ func main() { // create a configuration with 8 shares and 7 minimum shares to recover the // message config := &gosss.Config{ - Shares: 8, - Min: 7, + Shares: 4, + Min: 3, } // hide a message with the defined configuration - shares, err := gosss.HideMessage("secret", config) + msg := "688641b753f1c97526d6a767058a80fd6c6519f5bdb0a08098986b0478c8502b" + log.Printf("message to hide: %s", msg) + totalShares, err := gosss.HideMessage([]byte(msg), config) if err != nil { log.Fatalf("error hiding message: %v", err) } // print every share and exclude one share to test the recovery - excluded := 3 - requiredShares := []string{} - for i, s := range shares { - log.Printf("share: %s", s) - if i != excluded { - requiredShares = append(requiredShares, s) + requiredShares := [][]string{} + for _, secretShares := range totalShares { + log.Printf("shares: %v", secretShares) + // choose a random share to exclude + index := rand.Intn(len(secretShares)) + shares := []string{} + for i, share := range secretShares { + if i == index { + continue + } + shares = append(shares, share) } + requiredShares = append(requiredShares, shares) } // recover the message with the required shares message, err := gosss.RecoverMessage(requiredShares, nil) if err != nil { log.Fatalf("error recovering message: %v", err) } - log.Printf("recovered message: %s", message) + log.Printf("recovered message: %s", string(message)) } - ``` \ No newline at end of file diff --git a/config.go b/config.go index c548502..27b93cb 100644 --- a/config.go +++ b/config.go @@ -1,9 +1,6 @@ package gosss -import ( - "fmt" - "math/big" -) +import "math/big" // 12th Mersenne Prime (2^127 - 1) var DefaultPrime *big.Int = new(big.Int).Sub(new(big.Int).Exp(big.NewInt(2), big.NewInt(127), nil), big.NewInt(1)) @@ -40,10 +37,10 @@ func (c *Config) prepare(op operation) error { // and the minimum number of shares is greater than 1 and lower than // the number of shares if c.Shares <= 0 || c.Shares > maxShares { - return fmt.Errorf("number of shares must be between 1 and %d", maxShares) + return ErrConfigShares } if c.Min <= 1 || c.Min >= c.Shares { - return fmt.Errorf("minimum number of shares must be between 2 and %d", c.Shares-1) + return ErrConfigMin } case recoverOp: // for recover a message no checks are needed unless the prime number is @@ -51,7 +48,7 @@ func (c *Config) prepare(op operation) error { // it is needed break default: - return fmt.Errorf("unknown operation") + return ErrConfigOp } // if the prime number is not defined it will use the default prime number if c.Prime == nil { @@ -59,3 +56,10 @@ func (c *Config) prepare(op operation) error { } return nil } + +// maxSecretPartSize returns the maximum size of the secret part that can be +// hidden in a share, it is the size of the prime number in bytes minus 1, to +// ensure the secret part is smaller than the prime number. +func (c *Config) maxSecretPartSize() int { + return len(c.Prime.Bytes()) - 1 +} diff --git a/encode.go b/encode.go new file mode 100644 index 0000000..3ed9883 --- /dev/null +++ b/encode.go @@ -0,0 +1,69 @@ +package gosss + +import ( + "encoding/hex" + "fmt" + "math/big" +) + +// shareToStr converts a big.Int to a string. It uses the bytes of the big.Int +func shareToStr(index, share *big.Int) (string, error) { + if index.Cmp(big.NewInt(255)) > 0 || index.Cmp(big.NewInt(0)) < 0 { + return "", ErrShareIndex + } + bShare := share.Bytes() + // encode the index in a byte and append it at the end of the share + bIndex := index.Bytes() + if len(bIndex) == 0 { + bIndex = []byte{0} + } + fullShare := append(bShare, bIndex[0]) + return hex.EncodeToString(fullShare), nil +} + +// strToShare converts a string to a big.Int. It uses the bytes of the string +func strToShare(s string) (*big.Int, *big.Int, error) { + b, err := hex.DecodeString(s) + if err != nil { + return nil, nil, fmt.Errorf("%w: %w", ErrDecodeShare, err) + } + bIndex := b[len(b)-1] + bShare := b[:len(b)-1] + return new(big.Int).SetBytes([]byte{bIndex}), new(big.Int).SetBytes(bShare), nil +} + +// encodeShares function converts the x and y coordinates of the shares to +// strings. It returns the shares as strings. It uses the shareToStr function to +// encode the shares. It returns an error if the shares cannot be encoded. +func encodeShares(xs, ys []*big.Int) ([]string, error) { + if len(xs) == 0 || len(ys) == 0 || len(xs) != len(ys) { + return nil, ErrInvalidShares + } + // convert the shares to strings and append them to the result + shares := []string{} + for i := 0; i < len(xs); i++ { + share, err := shareToStr(xs[i], ys[i]) + if err != nil { + return nil, err + } + shares = append(shares, share) + } + return shares, nil +} + +// decodeShares function converts the strings of the shares to x and y +// coordinates of the shares. It uses the strToShare function to decode the +// shares. It returns an error if the shares cannot be decoded. +func decodeShares(shares []string) ([]*big.Int, []*big.Int, error) { + xs := []*big.Int{} + ys := []*big.Int{} + for _, strShare := range shares { + index, share, err := strToShare(strShare) + if err != nil { + return nil, nil, err + } + xs = append(xs, index) + ys = append(ys, share) + } + return xs, ys, nil +} diff --git a/encode_test.go b/encode_test.go new file mode 100644 index 0000000..8ea9a78 --- /dev/null +++ b/encode_test.go @@ -0,0 +1,72 @@ +package gosss + +import ( + "math/big" + "testing" +) + +func Test_shareToStrStrToShare(t *testing.T) { + // generate 10 random big.Int and convert them to string + for i := 0; i < 10; i++ { + idx := big.NewInt(int64(i)) + rand, err := randBigInt() + if err != nil { + t.Errorf("unexpected error: %v", err) + } + shareStr, err := shareToStr(idx, rand) + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + index, shareBack, err := strToShare(shareStr) + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + if index.Cmp(idx) != 0 { + t.Errorf("unexpected index: %d", index) + } + if rand.Cmp(shareBack) != 0 { + t.Errorf("unexpected share: %s", shareStr) + } + } +} + +func Test_encodeDecodeShares(t *testing.T) { + xs := []*big.Int{ + big.NewInt(1), + big.NewInt(2), + big.NewInt(3), + big.NewInt(4), + } + ys := []*big.Int{ + big.NewInt(100), + big.NewInt(200), + big.NewInt(300), + big.NewInt(400), + } + encodedShares, err := encodeShares(xs, ys) + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + decodedXs, decodedYs, err := decodeShares(encodedShares) + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + if len(xs) != len(decodedXs) || len(ys) != len(decodedYs) { + t.Errorf("unexpected shares length") + return + } + for i := 0; i < len(xs); i++ { + if xs[i].Cmp(decodedXs[i]) != 0 { + t.Errorf("unexpected x coordinate") + return + } + if ys[i].Cmp(decodedYs[i]) != 0 { + t.Errorf("unexpected y coordinate") + return + } + } +} diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..ee0d2c7 --- /dev/null +++ b/errors.go @@ -0,0 +1,19 @@ +package gosss + +import "fmt" + +var ( + // config + ErrConfigShares = fmt.Errorf("wrong number of shares, it must be between 1 and the maximum number of shares") + ErrConfigMin = fmt.Errorf("wrong minimum number of shares, it must be between 2 and the number of shares minus 1") + ErrConfigOp = fmt.Errorf("unknown operation") + // encode + ErrShareIndex = fmt.Errorf("the index must fit in a byte (0-255)") + ErrDecodeShare = fmt.Errorf("error decoding share") + ErrInvalidShares = fmt.Errorf("invalid shares") + // math + ErrReadingRandom = fmt.Errorf("error reading random number") + // sss + ErrRequiredConfig = fmt.Errorf("configuration is required") + ErrEncodeMessage = fmt.Errorf("error encoding message") +) diff --git a/helpers_test.go b/helpers_test.go deleted file mode 100644 index e7332ba..0000000 --- a/helpers_test.go +++ /dev/null @@ -1,46 +0,0 @@ -package gosss - -import ( - "math/big" - "testing" -) - -func Test_encodeDecodeMsg(t *testing.T) { - encodedPrivateMsg := msgToBigInt(examplePrivateMessage) - if encodedPrivateMsg == nil { - t.Errorf("unexpected nil encoded string") - return - } - decodedPrivateMsg := bigIntToMsg(encodedPrivateMsg) - if examplePrivateMessage != decodedPrivateMsg { - t.Errorf("unexpected decoded string: %s", decodedPrivateMsg) - } -} - -func Test_shareToStrStrToShare(t *testing.T) { - // generate 10 random big.Int and convert them to string - for i := 0; i < 10; i++ { - idx := big.NewInt(int64(i)) - rand, err := randBigInt() - if err != nil { - t.Errorf("unexpected error: %v", err) - } - shareStr, err := shareToStr(idx, rand) - if err != nil { - t.Errorf("unexpected error: %v", err) - return - } - index, shareBack, err := strToShare(shareStr) - if err != nil { - t.Errorf("unexpected error: %v", err) - return - } - if index.Cmp(idx) != 0 { - t.Errorf("unexpected index: %d", index) - } - if rand.Cmp(shareBack) != 0 { - t.Errorf("unexpected share: %s", shareStr) - } - } - -} diff --git a/helpers.go b/math.go similarity index 65% rename from helpers.go rename to math.go index 01b9655..5a24ca2 100644 --- a/helpers.go +++ b/math.go @@ -3,12 +3,45 @@ package gosss import ( "crypto/rand" "encoding/binary" - "encoding/hex" "fmt" "math" "math/big" ) +// randBigInt generates a random big.Int number. It uses the crypto/rand package +// to generate the random number. It returns an error if the random number +// cannot be generated. +func randBigInt() (*big.Int, error) { + var b [8]byte + _, err := rand.Read(b[:]) + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrReadingRandom, err) + } + // convert the bytes to an int64 and ensure it is non-negative + randomInt := int64(binary.BigEndian.Uint64(b[:])) & (1<<63 - 1) + // scale down the random int to the range [0, max) + return big.NewInt(randomInt % math.MaxInt64), nil +} + +// calcCoeffs function generates the coefficients for the polynomial. It takes +// the secret and the number of coefficients to generate. It returns the +// coefficients as a list of big.Int. It returns an error if the coefficients +// cannot be generated. The secret is the first coefficient of the polynomial, +// the rest of the coefficients are random numbers. +func calcCoeffs(secret *big.Int, k int) ([]*big.Int, error) { + // calculate k-1 random coefficients + randCoeffs := make([]*big.Int, k-1) + for i := 0; i < len(randCoeffs); i++ { + randCoeff, err := randBigInt() + if err != nil { + return nil, err + } + randCoeffs[i] = randCoeff + } + // include secret as the first coefficient and return the coefficients + return append([]*big.Int{secret}, randCoeffs...), nil +} + // solvePolynomial solves a polynomial with coefficients coeffs for x // and returns the result. It follows the formula: // f(x) = a0 + a1*x + a2*x^2 + ... + an*x^n @@ -25,6 +58,26 @@ func solvePolynomial(coeffs []*big.Int, x, prime *big.Int) *big.Int { return accum } +// calcShares function calculates the shares of the polynomial for the given +// coefficients and the number of shares to generate. It returns the x and y +// coordinates of the shares. The x coordinates are the index of the share and +// the y coordinates are the share itself. It uses the prime number to perform +// the modular operation in the finite field. It returns an error if the shares +// cannot be calculated. It skips the x = 0 coordinate because it is the secret +// itself. +func calcShares(coeffs []*big.Int, shares int, prime *big.Int) ([]*big.Int, []*big.Int) { + // calculate shares solving the polynomial for x = {1, shares}, x = 0 is the + // secret + var xs, yx []*big.Int + for i := 0; i < shares; i++ { + x := big.NewInt(int64(i + 1)) + y := solvePolynomial(coeffs, x, prime) + xs = append(xs, x) + yx = append(yx, y) + } + return xs, yx +} + // lagrangeInterpolation calculates the Lagrange interpolation for the given // points and a specific x value. The formula for Lagrange interpolation over a // finite field defined by a prime is: @@ -80,56 +133,3 @@ func lagrangeInterpolation(xCoords, yCoords []*big.Int, prime, _x *big.Int) *big } return result } - -// msgToBigInt converts a string to a big.Int. It uses the bytes of the string -// to create the big.Int. -func msgToBigInt(s string) *big.Int { - return new(big.Int).SetBytes([]byte(s)) -} - -// bigIntToMsg converts a big.Int to a string. It uses the bytes of the big.Int -// to create the string. -func bigIntToMsg(i *big.Int) string { - return string(i.Bytes()) -} - -// shareToStr converts a big.Int to a string. It uses the bytes of the big.Int -func shareToStr(index, share *big.Int) (string, error) { - if index.Cmp(big.NewInt(255)) > 0 || index.Cmp(big.NewInt(0)) < 0 { - return "", fmt.Errorf("the index must fit in a byte (0-255)") - } - bShare := share.Bytes() - // encode the index in a byte and append it at the end of the share - bIndex := index.Bytes() - if len(bIndex) == 0 { - bIndex = []byte{0} - } - fullShare := append(bShare, bIndex[0]) - return hex.EncodeToString(fullShare), nil -} - -// strToShare converts a string to a big.Int. It uses the bytes of the string -func strToShare(s string) (*big.Int, *big.Int, error) { - b, err := hex.DecodeString(s) - if err != nil { - return nil, nil, fmt.Errorf("error decoding share: %v", err) - } - bIndex := b[len(b)-1] - bShare := b[:len(b)-1] - return new(big.Int).SetBytes([]byte{bIndex}), new(big.Int).SetBytes(bShare), nil -} - -// randBigInt generates a random big.Int number. It uses the crypto/rand package -// to generate the random number. It returns an error if the random number -// cannot be generated. -func randBigInt() (*big.Int, error) { - var b [8]byte - _, err := rand.Read(b[:]) - if err != nil { - return nil, err - } - // convert the bytes to an int64 and ensure it is non-negative - randomInt := int64(binary.BigEndian.Uint64(b[:])) & (1<<63 - 1) - // scale down the random int to the range [0, max) - return big.NewInt(randomInt % math.MaxInt64), nil -} diff --git a/math_test.go b/math_test.go new file mode 100644 index 0000000..427d239 --- /dev/null +++ b/math_test.go @@ -0,0 +1,193 @@ +package gosss + +import ( + "math/big" + "testing" +) + +func Test_randBigInt(t *testing.T) { + generatedRands := make(map[int64]bool) + for i := 0; i < 100000; i++ { + rand, err := randBigInt() + if err != nil { + t.Fatalf("error generating random number: %v", err) + return + } + if _, ok := generatedRands[rand.Int64()]; ok { + t.Fatalf("duplicated random number") + return + } + generatedRands[rand.Int64()] = true + } +} + +func Test_calcCoeffs(t *testing.T) { + secret := big.NewInt(123456789) + coeffs, err := calcCoeffs(secret, 5) + if err != nil { + t.Fatalf("error calculating coefficients: %v", err) + return + } + if len(coeffs) != 5 { + t.Fatalf("invalid number of coefficients") + return + } + if coeffs[0].Cmp(secret) != 0 { + t.Fatalf("invalid secret coefficient") + return + } + checkedCoeffs := make(map[int64]bool) + for i := 1; i < len(coeffs); i++ { + if _, ok := checkedCoeffs[coeffs[i].Int64()]; ok { + t.Fatalf("duplicated coefficient") + return + } + checkedCoeffs[coeffs[i].Int64()] = true + } +} + +func Test_solvePolynomial(t *testing.T) { + // f(x) = 1 + 2x + 3x^2 + 4x^3 + basicCoeffs := []*big.Int{ + big.NewInt(1), + big.NewInt(2), + big.NewInt(3), + big.NewInt(4), + } + // x = 2, prime = 5 + basicX := big.NewInt(2) + basicPrime := big.NewInt(5) + // f(2) = 1 + 4 + 12 + 32 = 49 % 5 = 4 + basicExpected := big.NewInt(4) + basicResult := solvePolynomial(basicCoeffs, basicX, basicPrime) + if basicResult.Cmp(basicExpected) != 0 { + t.Errorf("Simple polynomial failed, expected %v, got %v", basicExpected, basicResult) + } + + // f(x) = 0 + zeroCoeffs := []*big.Int{ + big.NewInt(0), + big.NewInt(0), + big.NewInt(0), + big.NewInt(0), + } + // x = 5, prime = 7 + zeroX := big.NewInt(5) + zeroPrime := big.NewInt(7) + // f(5) = 0 + zeroExpected := big.NewInt(0) + zeroResult := solvePolynomial(zeroCoeffs, zeroX, zeroPrime) + if zeroResult.Cmp(zeroExpected) != 0 { + t.Errorf("Zero polynomial failed, expected %v, got %v", zeroExpected, zeroResult) + } + + // f(x) = -1 - 2x - 3x^2 - 4x^3 + negativeCoeffs := []*big.Int{ + big.NewInt(-1), + big.NewInt(-2), + big.NewInt(-3), + big.NewInt(-4), + } + // x = -2, prime = 5 + negativeX := big.NewInt(-2) + negativePrime := big.NewInt(5) + // f(-2) = -1 + 4 - 12 + 32 = 23 % 5 = 3 + negativeExpected := big.NewInt(3) + negativeResult := solvePolynomial(negativeCoeffs, negativeX, negativePrime) + if negativeResult.Cmp(negativeExpected) != 0 { + t.Errorf("Negative polynomial failed, expected %v, got %v", negativeExpected, negativeResult) + } + + // f(x) = 1 + 2x + 3x^2 + 4x^3 + 5x^4 + 6x^5 + 7x^6 + 8x^7 + highDegreeCoeffs := []*big.Int{ + big.NewInt(1), + big.NewInt(2), + big.NewInt(3), + big.NewInt(4), + big.NewInt(5), + big.NewInt(6), + big.NewInt(7), + big.NewInt(8), + } + highDegreeX := 2 + highDegreePrime := 13 + // f(2) = 1 + 4 + 12 + 32 + 80 + 192 + 448 + 1024 = 1793 % 13 = 12 + highDegreeExpected := big.NewInt(12) + highDegreeResult := solvePolynomial(highDegreeCoeffs, big.NewInt(int64(highDegreeX)), big.NewInt(int64(highDegreePrime))) + if highDegreeResult.Cmp(highDegreeExpected) != 0 { + t.Errorf("High degree polynomial failed, expected %v, got %v", highDegreeExpected, highDegreeResult) + } +} + +func Test_calcShares(t *testing.T) { + prime := big.NewInt(5) + // f(x) = 5 + x + 2x^2 + 3x^3 + coeffs := []*big.Int{ + big.NewInt(5), + big.NewInt(1), + big.NewInt(2), + big.NewInt(3), + } + k := len(coeffs) + xs, ys := calcShares(coeffs, k, prime) + if len(xs) != k { + t.Fatalf("invalid number of x coordinates") + return + } + if len(ys) != k { + t.Fatalf("invalid number of y coordinates") + return + } + expectedYs := []*big.Int{ + big.NewInt(1), // x = 1; f(1) = 5 + 1 + 2 + 3 = 11 % 5 = 1 + big.NewInt(4), // x = 2; f(2) = 5 + 2 + 8 + 24 = 39 % 5 = 4 + big.NewInt(2), // x = 3; f(3) = 5 + 3 + 18 + 81 = 107 % 5 = 2 + big.NewInt(3), // x = 4; f(4) = 5 + 4 + 32 + 192 = 233 % 5 = 3 + } + for i := 0; i < k; i++ { + if xs[i].Cmp(big.NewInt(int64(i+1))) != 0 { + t.Fatalf("invalid x coordinate, expected %v, got %v", i+1, xs[i]) + return + } + if ys[i].Cmp(expectedYs[i]) != 0 { + t.Fatalf("invalid y coordinate (%d), expected %v, got %v", i, expectedYs[i], ys[i]) + return + } + } +} + +func Test_lagrangeInterpolation(t *testing.T) { + prime := big.NewInt(5) + // f(x) = (6 + x + 2x^2 + 3x^3) % 5 + coeffs := []*big.Int{ + big.NewInt(6), + big.NewInt(1), + big.NewInt(2), + big.NewInt(3), + } + xs, ys := calcShares(coeffs, len(coeffs), prime) + // x = 0; f(0) = 6 % 5 = 1 + x0 := big.NewInt(0) + y0 := big.NewInt(1) + result0 := lagrangeInterpolation(xs, ys, prime, x0) + if result0.Cmp(y0) != 0 { + t.Errorf("x = 0 failed, expected %v, got %v", y0, result0) + return + } + + // x = 3; f(3) = 6 + 3 + 18 + 81 = 108 % 5 = 3 + x3 := big.NewInt(3) + y3 := big.NewInt(3) + result3 := lagrangeInterpolation(xs, ys, prime, x3) + if result3.Cmp(y3) != 0 { + t.Errorf("x = 3 failed, expected %v, got %v", y3, result3) + return + } + // x = 4; f(4) = 6 + 4 + 32 + 192 = 234 % 5 = 4 + x4 := big.NewInt(4) + y4 := big.NewInt(4) + result4 := lagrangeInterpolation(xs, ys, prime, x4) + if result4.Cmp(y4) != 0 { + t.Errorf("x = 4 failed, expected %v, got %v", y4, result4) + } +} diff --git a/message.go b/message.go new file mode 100644 index 0000000..bb111e2 --- /dev/null +++ b/message.go @@ -0,0 +1,34 @@ +package gosss + +import "math/big" + +// encodeMessage function splits a message into parts of a given size and +// converts them to big.Int. It returns an error if the message cannot be +// encoded into a big.Int. If the given message is smaller than the part size, +// it returns a single part. +func encodeMessage(message []byte, partSize int) []*big.Int { + if len(message) <= partSize { + return []*big.Int{new(big.Int).SetBytes(message)} + } + var parts []*big.Int + for i := 0; i < len(message); i += partSize { + end := i + partSize + if end > len(message) { + end = len(message) + } + parts = append(parts, new(big.Int).SetBytes(message[i:end])) + } + return parts +} + +// decodeMessage function converts the parts of a message to a single string. +// It returns the decoded message. It uses the bytes of the big.Int to decode +// the message, appending them to a single byte slice and converting it to a +// string. +func decodeMessage(parts []*big.Int) []byte { + var bMessage []byte + for _, part := range parts { + bMessage = append(bMessage, part.Bytes()...) + } + return bMessage +} diff --git a/message_test.go b/message_test.go new file mode 100644 index 0000000..07ba8f5 --- /dev/null +++ b/message_test.go @@ -0,0 +1,33 @@ +package gosss + +import ( + "bytes" + "math/big" + "testing" +) + +func Test_encodeDecodeMessage(t *testing.T) { + var maxPartSize = len(DefaultPrime.Bytes()) - 1 + var inputMessage = []byte("688641b753f1c97526d6a767058a80fd6c6519f5bdb0a08098986b0478c8502b") + var expectedParts = []*big.Int{ + new(big.Int).SetBytes([]byte("688641b753f1c97")), + new(big.Int).SetBytes([]byte("526d6a767058a80")), + new(big.Int).SetBytes([]byte("fd6c6519f5bdb0a")), + new(big.Int).SetBytes([]byte("08098986b0478c8")), + new(big.Int).SetBytes([]byte("502b")), + } + parts := encodeMessage(inputMessage, maxPartSize) + if len(parts) != len(expectedParts) { + t.Errorf("Expected %d parts but got %d", len(expectedParts), len(parts)) + return + } + for i, part := range parts { + if part.Cmp(expectedParts[i]) != 0 { + t.Errorf("Expected part %d to be %d but got %d", i, expectedParts[i], part) + return + } + } + if decodedMessage := decodeMessage(parts); !bytes.Equal(inputMessage, decodedMessage) { + t.Errorf("Expected %s but got %s", inputMessage, decodedMessage) + } +} diff --git a/sss.go b/sss.go index 005ac0e..f953a2d 100644 --- a/sss.go +++ b/sss.go @@ -1,9 +1,6 @@ package gosss -import ( - "fmt" - "math/big" -) +import "math/big" // HideMessage generates the shares of the message using the Shamir Secret // Sharing algorithm. It returns the shares as strings. The message is encoded @@ -12,46 +9,40 @@ import ( // configuration provided in the Config struct, if the prime number is not // defined it uses the 12th Mersenne Prime (2^127 - 1) as default. It returns // an error if the message cannot be encoded. -func HideMessage(message string, conf *Config) ([]string, error) { +func HideMessage(message []byte, conf *Config) ([][]string, error) { // the hide operation needs the minimum number of shares and the total // number of shares, so if the configuration is not provided, return an // error if conf == nil { - return nil, fmt.Errorf("configuration is required") + return nil, ErrRequiredConfig } // prepare the configuration to hide the message if err := conf.prepare(hideOp); err != nil { return nil, err } - // encode message to big.Int - secret := msgToBigInt(message) - if secret == nil { - return nil, fmt.Errorf("error encoding message") + // split the message to a list of big.Int to be used as shamir secrets + secrets := encodeMessage(message, conf.maxSecretPartSize()) + if len(secrets) == 0 { + return nil, ErrEncodeMessage } - // calculate k-1 random coefficients (k = min) - randCoeffs := make([]*big.Int, conf.Min-1) - for i := 0; i < len(randCoeffs); i++ { - randCoeff, err := randBigInt() + // generate the shares for each secret and return them encoded as strings + shares := [][]string{} + for _, secret := range secrets { + // calculate random coefficients for the polynomial + coeffs, err := calcCoeffs(secret, conf.Min) if err != nil { return nil, err } - randCoeffs[i] = randCoeff - } - // include secret as the first coefficient - coeffs := append([]*big.Int{secret}, randCoeffs...) - // calculate shares solving the polynomial for x = {1, shares}, x = 0 is the - // secret - totalShares := make([]string, conf.Shares) - for i := 0; i < conf.Shares; i++ { - x := big.NewInt(int64(i + 1)) - y := solvePolynomial(coeffs, x, conf.Prime) - share, err := shareToStr(x, y) + // calculate the shares with the polynomial and the prime number + xs, yx := calcShares(coeffs, conf.Shares, conf.Prime) + // convert the shares to strings and append them to the result + secretShares, err := encodeShares(xs, yx) if err != nil { return nil, err } - totalShares[i] = share + shares = append(shares, secretShares) } - return totalShares, nil + return shares, nil } // RecoverMessage recovers the message from the shares using the Shamir Secret @@ -62,7 +53,7 @@ func HideMessage(message string, conf *Config) ([]string, error) { // include the index of the share and the share itself, so the order of the // provided shares does not matter. It decodes the points of the polynomial from // the shares and calculates the Lagrange interpolation to recover the secret. -func RecoverMessage(shares []string, conf *Config) (string, error) { +func RecoverMessage(shares [][]string, conf *Config) ([]byte, error) { // the recover operation does not need the minimum number of shares or the // total number of shares, so if the configuration is not provided, create a // empty configuration before prepare the it. @@ -71,19 +62,21 @@ func RecoverMessage(shares []string, conf *Config) (string, error) { } // prepare the configuration to recover the message if err := conf.prepare(recoverOp); err != nil { - return "", err + return nil, err } - // convert shares to big.Ints points coordinates - x := make([]*big.Int, len(shares)) - y := make([]*big.Int, len(shares)) - for i, strShare := range shares { - index, share, err := strToShare(strShare) + parts := []*big.Int{} + for _, secretShares := range shares { + // convert shares to big.Ints points coordinates + xs, ys, err := decodeShares(secretShares) if err != nil { - return "", err + return nil, err } - x[i] = index - y[i] = share + // calculate the secret part using the Lagrange interpolation, the + // secret part is the y coordinate for x = 0 + result := lagrangeInterpolation(xs, ys, conf.Prime, big.NewInt(0)) + // append the secret part to the result + parts = append(parts, result) } - result := lagrangeInterpolation(x, y, conf.Prime, big.NewInt(0)) - return bigIntToMsg(result), nil + // decode the message from the parts and return it + return decodeMessage(parts), nil } diff --git a/sss_test.go b/sss_test.go index e86610f..78d91cb 100644 --- a/sss_test.go +++ b/sss_test.go @@ -1,46 +1,43 @@ package gosss import ( + "bytes" "math/rand" "testing" ) -const examplePrivateMessage = "aaa" +var examplePrivateMessage = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit. Quisque nisl turpis, molestie sit amet ullamcorper sit amet, cursus in diam. Aenean urna nunc, hendrerit sed ipsum suscipit, lacinia feugiat metus. Phasellus pulvinar, tellus sit amet euismod vulputate, justo nisi finibus tellus, a ultrices odio mi vitae nibh. Duis accumsan nunc.") func TestHideRecoverMessage(t *testing.T) { config := &Config{ - Shares: 8, - Min: 7, + Shares: 4, + Min: 3, } totalShares, err := HideMessage(examplePrivateMessage, config) if err != nil { t.Errorf("unexpected error: %v", err) return } - if len(totalShares) != config.Shares { - t.Errorf("unexpected number of shares: %d", len(totalShares)) - return - } - // get some shares randomly of the total and recover the message - shares := []string{} - chosen := map[string]int{} - for len(chosen) < config.Min { - // random number between 0 and 35 - idx := rand.Intn(config.Shares) - _, ok := chosen[totalShares[idx]] - for ok { - idx = rand.Intn(config.Shares) - _, ok = chosen[totalShares[idx]] + + candidateShares := [][]string{} + for _, secretShares := range totalShares { + // choose a random index to remove a share + shares := []string{} + index := rand.Intn(len(secretShares)) + for i, share := range secretShares { + if i == index { + continue + } + shares = append(shares, share) } - chosen[totalShares[idx]] = idx - shares = append(shares, totalShares[idx]) + candidateShares = append(candidateShares, shares) } - message, err := RecoverMessage(shares, config) + message, err := RecoverMessage(candidateShares, config) if err != nil { t.Errorf("unexpected error: %v", err) return } - if examplePrivateMessage != message { + if !bytes.Equal(message, examplePrivateMessage) { t.Errorf("unexpected message: %s", message) } } diff --git a/test/main.go b/test/main.go new file mode 100644 index 0000000..9ee29c8 --- /dev/null +++ b/test/main.go @@ -0,0 +1,45 @@ +package main + +import ( + "log" + "math/rand" + + "github.com/lucasmenendez/gosss" +) + +func main() { + // create a configuration with 8 shares and 7 minimum shares to recover the + // message + config := &gosss.Config{ + Shares: 4, + Min: 3, + } + // hide a message with the defined configuration + msg := "688641b753f1c97526d6a767058a80fd6c6519f5bdb0a08098986b0478c8502b" + log.Println("message to hide: ", msg) + totalShares, err := gosss.HideMessage(msg, config) + if err != nil { + log.Fatalf("error hiding message: %v", err) + } + // print every share and exclude one share to test the recovery + requiredShares := [][]string{} + for _, secretShares := range totalShares { + log.Printf("shares: %v", secretShares) + // choose a random share to exclude + index := rand.Intn(len(secretShares)) + shares := []string{} + for i, share := range secretShares { + if i == index { + continue + } + shares = append(shares, share) + } + requiredShares = append(requiredShares, shares) + } + // recover the message with the required shares + message, err := gosss.RecoverMessage(requiredShares, nil) + if err != nil { + log.Fatalf("error recovering message: %v", err) + } + log.Printf("recovered message: %s", message) +}