Skip to content

Commit

Permalink
fix parallel tests and break balancer api
Browse files Browse the repository at this point in the history
  • Loading branch information
pysel committed Feb 12, 2024
1 parent 744f6ae commit 20fbeee
Show file tree
Hide file tree
Showing 8 changed files with 34 additions and 61 deletions.
10 changes: 3 additions & 7 deletions balancer/2pc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,16 @@ import (
"github.com/stretchr/testify/require"
)

var (
TestDBBalancer = "balancer"
)

func TestTwoPhaseCommit(t *testing.T) {
defer os.RemoveAll(TestDBBalancer + t.Name())
addrs, paths := testutil.StartXPartitionServers(2)
defer os.RemoveAll(balancer.BalancerDBPath + t.Name())
addrs, paths := testutil.StartXPartitionServers(t, 2)
defer testutil.RemovePaths(paths)

ctx := context.Background()

partitionAddr1, partitionAddr2 := addrs[0], addrs[1]

b := balancer.NewBalancerTest(t, 1)
b := balancer.NewBalancer(balancer.BalancerDBPath+t.Name(), 1)

err := b.RegisterPartition(ctx, partitionAddr1.String())
require.NoError(t, err)
Expand Down
10 changes: 5 additions & 5 deletions balancer/balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,16 @@ type Balancer struct {
}

// NewBalancer returns a new balancer instance.
func NewBalancer(goalReplicaRanges int) *Balancer {
db, err := leveldb.NewLevelDB(BalancerDBPath)
func NewBalancer(dbPath string, goalReplicaRanges int) *Balancer {
db, err := leveldb.NewLevelDB(dbPath)
if err != nil {
panic(err)
}

b := &Balancer{
DB: db,
rangeToViews: make(map[hashrange.RangeKey]*rangeview.RangeView),
coverage: coverage.GetCoverage(),
coverage: &coverage.Coverage{Ticks: nil},
clientIdToLamport: NewClientIdToLamport(),
}

Expand Down Expand Up @@ -160,11 +160,11 @@ func (b *Balancer) setupCoverage(goalReplicaRanges int) error {
b.coverage.AddTick(coverage.NewTick(hashrange.MaxInt, 0))
return nil
}

// Create a tick for each partition
for i := 0; i <= goalReplicaRanges; i++ {
numerator := new(big.Int).Mul(big.NewInt(int64(i)), hashrange.MaxInt)
value := new(big.Int).Div(numerator, big.NewInt(int64(goalReplicaRanges)))
denominator := big.NewInt(int64(goalReplicaRanges))
value := new(big.Int).Div(numerator, denominator)
b.coverage.AddTick(coverage.NewTick(value, 0))
}

Expand Down
20 changes: 10 additions & 10 deletions balancer/balancer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ import (
)

func TestRegisterGetPartition(t *testing.T) {
defer os.RemoveAll(TestDBBalancer + t.Name())
addrs, paths := testutil.StartXPartitionServers(1)
defer os.RemoveAll(balancer.BalancerDBPath + t.Name())
addrs, paths := testutil.StartXPartitionServers(t, 1)
defer testutil.RemovePaths(paths)

ctx := context.Background()

addr := addrs[0]
b2 := balancer.NewBalancerTest(t, 2)
b2 := balancer.NewBalancer(balancer.BalancerDBPath+t.Name(), 2)

err := b2.RegisterPartition(ctx, addr.String())
require.NoError(t, err)
Expand All @@ -33,11 +33,11 @@ func TestRegisterGetPartition(t *testing.T) {
}

func TestBalancerInit(t *testing.T) {
defer os.RemoveAll(TestDBBalancer + t.Name())
defer os.RemoveAll(balancer.BalancerDBPath + t.Name())

goalReplicaRanges := 3

b := balancer.NewBalancerTest(t, goalReplicaRanges)
b := balancer.NewBalancer(balancer.BalancerDBPath+t.Name(), goalReplicaRanges)
require.Equal(t, b.GetCoverageSize(), goalReplicaRanges+1)

expectedFirstTickValue := big.NewInt(0)
Expand All @@ -55,16 +55,16 @@ func TestBalancerInit(t *testing.T) {
}

func TestGetNextPartitionRange(t *testing.T) {
defer os.RemoveAll(TestDBBalancer + t.Name())
addrs, paths := testutil.StartXPartitionServers(2)
defer os.RemoveAll(balancer.BalancerDBPath + t.Name())
addrs, paths := testutil.StartXPartitionServers(t, 2)
defer testutil.RemovePaths(paths)

addr1, addr2 := addrs[0], addrs[1]

ctx := context.Background()

// SUT
b2 := balancer.NewBalancerTest(t, 2)
b2 := balancer.NewBalancer(balancer.BalancerDBPath+t.Name(), 2)
nextPartitionRange, _, _ := b2.GetNextPartitionRange()
// defaultHashrange is full sha256 domain, in case of 2 nodes, first node's domain should be half
require.Equal(t, hashrange.NewRange(big.NewInt(0).Bytes(), testutil.HalfShaDomain.Bytes()).AsKey(), nextPartitionRange)
Expand All @@ -89,9 +89,9 @@ func TestGetNextPartitionRange(t *testing.T) {
}

func TestClientIdToLamport(t *testing.T) {
defer os.RemoveAll(TestDBBalancer + t.Name())
defer os.RemoveAll(balancer.BalancerDBPath + t.Name())

b := balancer.NewBalancerTest(t, 2)
b := balancer.NewBalancer(balancer.BalancerDBPath+t.Name(), 2)

require.Equal(t, uint64(1), b.NextClientId()) // first call should return 1

Expand Down
12 changes: 6 additions & 6 deletions balancer/coverage/coverage.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ func (c *Coverage) String() string {

// GetCoverage returns a Coverage.
// Singletone pattern is used here.
func GetCoverage() *Coverage {
if CreatedCoverage == nil {
CreatedCoverage = &Coverage{nil}
}
return CreatedCoverage
}
// func GetCoverage() *Coverage {
// if CreatedCoverage == nil {
// CreatedCoverage = &Coverage{nil}
// }
// return &Coverage{nil}
// }

// addTick iterates over the list of ticks until
func (c *Coverage) AddTick(t *pbbalancer.Tick) {
Expand Down
24 changes: 0 additions & 24 deletions balancer/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,12 @@ package balancer

import (
"math/big"
"testing"

coverage "github.com/pysel/dkvs/balancer/coverage"
"github.com/pysel/dkvs/balancer/rangeview"
leveldb "github.com/pysel/dkvs/db/leveldb"
pbbalancer "github.com/pysel/dkvs/prototypes/balancer"
"github.com/pysel/dkvs/types/hashrange"
"github.com/stretchr/testify/require"
)

var balancerName = "balancer"

type (
ClientIdToLamport clientIdToLamport
)
Expand All @@ -37,21 +31,3 @@ func (b *Balancer) GetRangeFromDigest(digest []byte) (*hashrange.Range, error) {
func (b *Balancer) GetRangeToViews() map[hashrange.RangeKey]*rangeview.RangeView {
return b.rangeToViews
}

// NewBalancerTest returns a new balancer instance with an independent Coverage every time.
func NewBalancerTest(t *testing.T, goalReplicaRanges int) *Balancer {
balancerName = "balancer" + t.Name()
db, err := leveldb.NewLevelDB(balancerName)

require.NoError(t, err)

b := &Balancer{
DB: db,
rangeToViews: make(map[hashrange.RangeKey]*rangeview.RangeView),
coverage: &coverage.Coverage{},
clientIdToLamport: NewClientIdToLamport(),
}

require.NoError(t, b.setupCoverage(goalReplicaRanges))
return b
}
4 changes: 2 additions & 2 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ var (

func TestClient(t *testing.T) {
// setup balancer server to which the client will be connected
balancerAddress, closer := testutil.BalancerClientWith2Partitions()
balancerAddress, closer := testutil.BalancerClientWith2Partitions(t)

defer closer()

Expand All @@ -43,7 +43,7 @@ func TestClient(t *testing.T) {

func TestClientParallel(t *testing.T) {
// setup balancer server to which the client will be connected
balancerAddress, closer := testutil.BalancerClientWith2Partitions()
balancerAddress, closer := testutil.BalancerClientWith2Partitions(t)

defer closer()

Expand Down
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func main() {
panic("invalid parameter for desired amount of partitions")
}

b := balancer.NewBalancer(goalReplicas)
b := balancer.NewBalancer(balancer.BalancerDBPath, goalReplicas)
server := balancer.RegisterBalancerServer(b)

wg, addr := shared.StartListeningOnPort(server, uint64(port))
Expand Down
13 changes: 7 additions & 6 deletions testutil/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net"
"os"
"strconv"
"testing"

"github.com/pysel/dkvs/balancer"

Expand Down Expand Up @@ -94,11 +95,11 @@ func StartPartitionClientToBufferedServer(ctx context.Context) (net.Addr, pbpart
return lis.Addr(), pbpartition.NewPartitionServiceClient(conn), closer
}

func StartXPartitionServers(x int) ([]net.Addr, []string) {
func StartXPartitionServers(t *testing.T, x int) ([]net.Addr, []string) {
addrs := make([]net.Addr, x)
dbPaths := make([]string, x)
for i := 0; i < x; i++ {
path := TestDBPath + strconv.Itoa(i) + "test"
path := TestDBPath + strconv.Itoa(i) + "test" + t.Name()
p := partition.NewPartition(path)
s := partition.RegisterPartitionServer(p)
_, addr := shared.StartListeningOnPort(s, 0)
Expand All @@ -109,12 +110,12 @@ func StartXPartitionServers(x int) ([]net.Addr, []string) {
return addrs, dbPaths
}

func BalancerClientWith2Partitions() (net.Addr, func()) {
func BalancerClientWith2Partitions(t *testing.T) (net.Addr, func()) {
ctx := context.Background()
addrs, dbPaths := StartXPartitionServers(2)
addrs, dbPaths := StartXPartitionServers(t, 2)

// register partitions
b := balancer.NewBalancer(2)
b := balancer.NewBalancer(balancer.BalancerDBPath+t.Name(), 2)
b.RegisterPartition(ctx, addrs[0].String())
b.RegisterPartition(ctx, addrs[1].String())

Expand All @@ -123,7 +124,7 @@ func BalancerClientWith2Partitions() (net.Addr, func()) {

return addr, func() {
// remove all databases - one for balancer and one for each partitin
os.RemoveAll(balancer.BalancerDBPath)
os.RemoveAll(balancer.BalancerDBPath + t.Name())
for _, path := range dbPaths {
err := os.RemoveAll(path)
if err != nil {
Expand Down

0 comments on commit 20fbeee

Please sign in to comment.