Skip to content

add merkleProof type #1014

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 0 additions & 97 deletions stores/metadata_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"fmt"
"os"
"reflect"
"sort"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -4308,99 +4307,3 @@ func TestUpdateObjectReuseSlab(t *testing.T) {
}
}
}

func TestTypeCurrency(t *testing.T) {
ss := newTestSQLStore(t, defaultTestSQLStoreConfig)
defer ss.Close()

// prepare the table
if isSQLite(ss.db) {
if err := ss.db.Exec("CREATE TABLE currencies (id INTEGER PRIMARY KEY AUTOINCREMENT,c BLOB);").Error; err != nil {
t.Fatal(err)
}
} else {
if err := ss.db.Exec("CREATE TABLE currencies (id INT AUTO_INCREMENT PRIMARY KEY, c BLOB);").Error; err != nil {
t.Fatal(err)
}
}

// insert currencies in random order
if err := ss.db.Exec("INSERT INTO currencies (c) VALUES (?),(?),(?);", bCurrency(types.MaxCurrency), bCurrency(types.NewCurrency64(1)), bCurrency(types.ZeroCurrency)).Error; err != nil {
t.Fatal(err)
}

// fetch currencies and assert they're sorted
var currencies []bCurrency
if err := ss.db.Raw(`SELECT c FROM currencies ORDER BY c ASC`).Scan(&currencies).Error; err != nil {
t.Fatal(err)
} else if !sort.SliceIsSorted(currencies, func(i, j int) bool {
return types.Currency(currencies[i]).Cmp(types.Currency(currencies[j])) < 0
}) {
t.Fatal("currencies not sorted", currencies)
}

// convenience variables
c0 := currencies[0]
c1 := currencies[1]
cM := currencies[2]

tests := []struct {
a bCurrency
b bCurrency
cmp string
}{
{
a: c0,
b: c1,
cmp: "<",
},
{
a: c1,
b: c0,
cmp: ">",
},
{
a: c0,
b: c1,
cmp: "!=",
},
{
a: c1,
b: c1,
cmp: "=",
},
{
a: c0,
b: cM,
cmp: "<",
},
{
a: cM,
b: c0,
cmp: ">",
},
{
a: cM,
b: cM,
cmp: "=",
},
}
for i, test := range tests {
var result bool
query := fmt.Sprintf("SELECT ? %s ?", test.cmp)
if !isSQLite(ss.db) {
query = strings.Replace(query, "?", "HEX(?)", -1)
}
if err := ss.db.Raw(query, test.a, test.b).Scan(&result).Error; err != nil {
t.Fatal(err)
} else if !result {
t.Errorf("unexpected result in case %d/%d: expected %v %s %v to be true", i+1, len(tests), types.Currency(test.a).String(), test.cmp, types.Currency(test.b).String())
} else if test.cmp == "<" && types.Currency(test.a).Cmp(types.Currency(test.b)) >= 0 {
t.Fatal("invalid result")
} else if test.cmp == ">" && types.Currency(test.a).Cmp(types.Currency(test.b)) <= 0 {
t.Fatal("invalid result")
} else if test.cmp == "=" && types.Currency(test.a).Cmp(types.Currency(test.b)) != 0 {
t.Fatal("invalid result")
}
}
}
39 changes: 39 additions & 0 deletions stores/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
)

const (
proofHashSize = 32
secretKeySize = 32
)

Expand All @@ -35,6 +36,10 @@ type (
balance big.Int
unsigned64 uint64 // used for storing large uint64 values in sqlite
secretKey []byte

// NOTE: we have to wrap the proof here because Gorm can't scan bytes into
// multiple slices, all bytes are scanned into the first row
merkleProof struct{ proof []types.Hash256 }
)

// GormDataType implements gorm.GormDataTypeInterface.
Expand Down Expand Up @@ -341,6 +346,7 @@ func (u unsigned64) Value() (driver.Value, error) {
return int64(u), nil
}

// GormDataType implements gorm.GormDataTypeInterface.
func (bCurrency) GormDataType() string {
return "bytes"
}
Expand All @@ -366,3 +372,36 @@ func (sc bCurrency) Value() (driver.Value, error) {
binary.BigEndian.PutUint64(buf[8:], sc.Lo)
return buf, nil
}

// GormDataType implements gorm.GormDataTypeInterface.
func (mp *merkleProof) GormDataType() string {
return "bytes"
}

// Scan scans value into mp, implements sql.Scanner interface.
func (mp *merkleProof) Scan(value interface{}) error {
bytes, ok := value.([]byte)
if !ok {
return errors.New(fmt.Sprint("failed to unmarshal merkleProof value:", value))
} else if len(bytes) == 0 || len(bytes)%proofHashSize != 0 {
return fmt.Errorf("failed to unmarshal merkleProof value due to invalid number of bytes %v", len(bytes))
}

n := len(bytes) / proofHashSize
mp.proof = make([]types.Hash256, n)
for i := 0; i < n; i++ {
copy(mp.proof[i][:], bytes[:proofHashSize])
bytes = bytes[proofHashSize:]
}
return nil
}

// Value returns a merkle proof value, implements driver.Valuer interface.
func (mp merkleProof) Value() (driver.Value, error) {
var i int
out := make([]byte, len(mp.proof)*proofHashSize)
for _, ph := range mp.proof {
i += copy(out[i:], ph[:])
}
return out, nil
}
154 changes: 154 additions & 0 deletions stores/types_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
package stores

import (
"fmt"
"sort"
"strings"
"testing"

"go.sia.tech/core/types"
)

func TestTypeCurrency(t *testing.T) {
ss := newTestSQLStore(t, defaultTestSQLStoreConfig)
defer ss.Close()

// prepare the table
if isSQLite(ss.db) {
if err := ss.db.Exec("CREATE TABLE currencies (id INTEGER PRIMARY KEY AUTOINCREMENT,c BLOB);").Error; err != nil {
t.Fatal(err)
}
} else {
if err := ss.db.Exec("CREATE TABLE currencies (id INT AUTO_INCREMENT PRIMARY KEY, c BLOB);").Error; err != nil {
t.Fatal(err)
}
}

// insert currencies in random order
if err := ss.db.Exec("INSERT INTO currencies (c) VALUES (?),(?),(?);", bCurrency(types.MaxCurrency), bCurrency(types.NewCurrency64(1)), bCurrency(types.ZeroCurrency)).Error; err != nil {
t.Fatal(err)
}

// fetch currencies and assert they're sorted
var currencies []bCurrency
if err := ss.db.Raw(`SELECT c FROM currencies ORDER BY c ASC`).Scan(&currencies).Error; err != nil {
t.Fatal(err)
} else if !sort.SliceIsSorted(currencies, func(i, j int) bool {
return types.Currency(currencies[i]).Cmp(types.Currency(currencies[j])) < 0
}) {
t.Fatal("currencies not sorted", currencies)
}

// convenience variables
c0 := currencies[0]
c1 := currencies[1]
cM := currencies[2]

tests := []struct {
a bCurrency
b bCurrency
cmp string
}{
{
a: c0,
b: c1,
cmp: "<",
},
{
a: c1,
b: c0,
cmp: ">",
},
{
a: c0,
b: c1,
cmp: "!=",
},
{
a: c1,
b: c1,
cmp: "=",
},
{
a: c0,
b: cM,
cmp: "<",
},
{
a: cM,
b: c0,
cmp: ">",
},
{
a: cM,
b: cM,
cmp: "=",
},
}
for i, test := range tests {
var result bool
query := fmt.Sprintf("SELECT ? %s ?", test.cmp)
if !isSQLite(ss.db) {
query = strings.Replace(query, "?", "HEX(?)", -1)
}
if err := ss.db.Raw(query, test.a, test.b).Scan(&result).Error; err != nil {
t.Fatal(err)
} else if !result {
t.Errorf("unexpected result in case %d/%d: expected %v %s %v to be true", i+1, len(tests), types.Currency(test.a).String(), test.cmp, types.Currency(test.b).String())
} else if test.cmp == "<" && types.Currency(test.a).Cmp(types.Currency(test.b)) >= 0 {
t.Fatal("invalid result")
} else if test.cmp == ">" && types.Currency(test.a).Cmp(types.Currency(test.b)) <= 0 {
t.Fatal("invalid result")
} else if test.cmp == "=" && types.Currency(test.a).Cmp(types.Currency(test.b)) != 0 {
t.Fatal("invalid result")
}
}
}

func TestTypeMerkleProof(t *testing.T) {
ss := newTestSQLStore(t, defaultTestSQLStoreConfig)
defer ss.Close()

// prepare the table
if isSQLite(ss.db) {
if err := ss.db.Exec("CREATE TABLE merkle_proofs (id INTEGER PRIMARY KEY AUTOINCREMENT,merkle_proof BLOB);").Error; err != nil {
t.Fatal(err)
}
} else {
ss.db.Exec("DROP TABLE IF EXISTS merkle_proofs;")
if err := ss.db.Exec("CREATE TABLE merkle_proofs (id INT AUTO_INCREMENT PRIMARY KEY, merkle_proof BLOB);").Error; err != nil {
t.Fatal(err)
}
}

// insert merkle proof
mp1 := merkleProof{proof: []types.Hash256{{3}, {1}, {2}}}
mp2 := merkleProof{proof: []types.Hash256{{4}}}
if err := ss.db.Exec("INSERT INTO merkle_proofs (merkle_proof) VALUES (?), (?);", mp1, mp2).Error; err != nil {
t.Fatal(err)
}

// fetch first proof
var first merkleProof
if err := ss.db.
Raw(`SELECT merkle_proof FROM merkle_proofs`).
Take(&first).
Error; err != nil {
t.Fatal(err)
} else if first.proof[0] != (types.Hash256{3}) || first.proof[1] != (types.Hash256{1}) || first.proof[2] != (types.Hash256{2}) {
t.Fatalf("unexpected proof %+v", first)
}

// fetch both proofs
var both []merkleProof
if err := ss.db.
Raw(`SELECT merkle_proof FROM merkle_proofs`).
Scan(&both).
Error; err != nil {
t.Fatal(err)
} else if len(both) != 2 {
t.Fatalf("unexpected number of proofs: %d", len(both))
} else if both[1].proof[0] != (types.Hash256{4}) {
t.Fatalf("unexpected proof %+v", both)
}
}