diff --git a/go.mod b/go.mod index 2708e71..43c2589 100644 --- a/go.mod +++ b/go.mod @@ -229,6 +229,7 @@ require ( github.com/schollz/progressbar/v3 v3.18.0 // indirect github.com/securego/gosec/v2 v2.22.11 // indirect github.com/shirou/gopsutil v3.21.4-0.20210419000835-c7a38de76ee5+incompatible // indirect + github.com/shopspring/decimal v1.4.0 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/sivchari/containedctx v1.0.3 // indirect github.com/sonatard/noctx v0.4.0 // indirect diff --git a/go.sum b/go.sum index 0a98b2f..48b4b50 100644 --- a/go.sum +++ b/go.sum @@ -600,6 +600,8 @@ github.com/sergi/go-diff v1.2.0 h1:XU+rvMAioB0UC3q1MFrIQy4Vo5/4VsRDQQXHsEya6xQ= github.com/sergi/go-diff v1.2.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= github.com/shirou/gopsutil v3.21.4-0.20210419000835-c7a38de76ee5+incompatible h1:Bn1aCHHRnjv4Bl16T8rcaFjYSrGrIZvpiGO6P3Q4GpU= github.com/shirou/gopsutil v3.21.4-0.20210419000835-c7a38de76ee5+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= +github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= +github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= github.com/shurcooL/go v0.0.0-20180423040247-9e1955d9fb6e/go.mod h1:TDJrrUr11Vxrven61rcy3hJMUqaf/CLWYhHNPmT14Lk= github.com/shurcooL/go-goon v0.0.0-20170922171312-37c2f522c041/go.mod h1:N5mDOmsrJOB+vfqUK+7DmDyjhSLIIBnXo9lvZJj3MWQ= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= diff --git a/internal/liquidity/errors.go b/internal/liquidity/errors.go index 07f6b1c..11baba7 100644 --- a/internal/liquidity/errors.go +++ b/internal/liquidity/errors.go @@ -10,8 +10,13 @@ var ( ErrInvalidBinSize = errors.New("bin size must be positive") // ErrInvalidTickRange is returned when tick range is invalid. - ErrInvalidTickRange = errors.New("tick range must be positive") + ErrInvalidTickRange = errors.New("invalid tick range") // ErrContractCall is returned when contract call fails. ErrContractCall = errors.New("contract call failed") + + ErrNoTickRanges = errors.New("no tick ranges") + ErrInvalidWeight = errors.New("invalid weight") + ErrInvalidTotalAmount = errors.New("invalid total amount") + ErrZeroTotalWeight = errors.New("zero total weight") ) diff --git a/internal/liquidity/rebalance.go b/internal/liquidity/rebalance.go new file mode 100644 index 0000000..b61fa91 --- /dev/null +++ b/internal/liquidity/rebalance.go @@ -0,0 +1,55 @@ +package liquidity + +import "github.com/shopspring/decimal" + +func BuildRebalanceAllocations( + ranges []TickRangeWeight, + totalAmount decimal.Decimal, +) ([]RebalanceAllocation, error) { + if len(ranges) == 0 { + return nil, ErrNoTickRanges + } + + if totalAmount.IsNegative() { + return nil, ErrInvalidTotalAmount + } + + sumWeight := decimal.Zero + for _, r := range ranges { + if r.TickLower >= r.TickUpper { + return nil, ErrInvalidTickRange + } + + if r.Weight.IsNegative() { + return nil, ErrInvalidWeight + } + + sumWeight = sumWeight.Add(r.Weight) + } + + if sumWeight.IsZero() { + return nil, ErrZeroTotalWeight + } + + allocations := make([]RebalanceAllocation, len(ranges)) + remaining := totalAmount + + for i, r := range ranges { + var amount decimal.Decimal + if i == len(ranges)-1 { + amount = remaining + } else { + amount = totalAmount.Mul(r.Weight).Div(sumWeight) + remaining = remaining.Sub(amount) + } + + allocations[i] = RebalanceAllocation{ + TickLower: r.TickLower, + TickUpper: r.TickUpper, + Weight: r.Weight, + Amount: amount, + } + } + + return allocations, nil +} diff --git a/internal/liquidity/rebalance_test.go b/internal/liquidity/rebalance_test.go new file mode 100644 index 0000000..94d41ff --- /dev/null +++ b/internal/liquidity/rebalance_test.go @@ -0,0 +1,172 @@ +package liquidity_test + +import ( + "errors" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/shopspring/decimal" + + "remora/internal/liquidity" +) + +func TestBuildRebalanceAllocations(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + ranges []liquidity.TickRangeWeight + amount decimal.Decimal + want []liquidity.RebalanceAllocation + wantErr bool + wantErrIs error + wantAmount decimal.Decimal + }{ + { + name: "success - weighted allocation", + ranges: []liquidity.TickRangeWeight{ + {TickLower: -100, TickUpper: 0, Weight: mustDecimal(t, "1")}, + {TickLower: 0, TickUpper: 100, Weight: mustDecimal(t, "3")}, + }, + amount: mustDecimal(t, "100"), + want: []liquidity.RebalanceAllocation{ + { + TickLower: -100, + TickUpper: 0, + Weight: mustDecimal(t, "1"), + Amount: mustDecimal(t, "25"), + }, + { + TickLower: 0, + TickUpper: 100, + Weight: mustDecimal(t, "3"), + Amount: mustDecimal(t, "75"), + }, + }, + wantAmount: mustDecimal(t, "100"), + }, + { + name: "success - zero weight keeps entry with zero amounts", + ranges: []liquidity.TickRangeWeight{ + {TickLower: -200, TickUpper: -100, Weight: mustDecimal(t, "0")}, + {TickLower: -100, TickUpper: 100, Weight: mustDecimal(t, "2")}, + }, + amount: mustDecimal(t, "10"), + want: []liquidity.RebalanceAllocation{ + { + TickLower: -200, + TickUpper: -100, + Weight: mustDecimal(t, "0"), + Amount: mustDecimal(t, "0"), + }, + { + TickLower: -100, + TickUpper: 100, + Weight: mustDecimal(t, "2"), + Amount: mustDecimal(t, "10"), + }, + }, + wantAmount: mustDecimal(t, "10"), + }, + { + name: "error - no ranges", + ranges: nil, + amount: mustDecimal(t, "1"), + wantErr: true, + wantErrIs: liquidity.ErrNoTickRanges, + wantAmount: mustDecimal(t, "0"), + }, + { + name: "error - invalid tick range", + ranges: []liquidity.TickRangeWeight{ + {TickLower: 100, TickUpper: 100, Weight: mustDecimal(t, "1")}, + }, + amount: mustDecimal(t, "1"), + wantErr: true, + wantErrIs: liquidity.ErrInvalidTickRange, + wantAmount: mustDecimal(t, "0"), + }, + { + name: "error - negative weight", + ranges: []liquidity.TickRangeWeight{ + {TickLower: 0, TickUpper: 10, Weight: mustDecimal(t, "-1")}, + }, + amount: mustDecimal(t, "1"), + wantErr: true, + wantErrIs: liquidity.ErrInvalidWeight, + wantAmount: mustDecimal(t, "0"), + }, + { + name: "error - negative total amount", + ranges: []liquidity.TickRangeWeight{ + {TickLower: 0, TickUpper: 10, Weight: mustDecimal(t, "1")}, + }, + amount: mustDecimal(t, "-1"), + wantErr: true, + wantErrIs: liquidity.ErrInvalidTotalAmount, + wantAmount: mustDecimal(t, "0"), + }, + { + name: "error - total weight zero", + ranges: []liquidity.TickRangeWeight{ + {TickLower: 0, TickUpper: 10, Weight: mustDecimal(t, "0")}, + {TickLower: 10, TickUpper: 20, Weight: mustDecimal(t, "0")}, + }, + amount: mustDecimal(t, "1"), + wantErr: true, + wantErrIs: liquidity.ErrZeroTotalWeight, + wantAmount: mustDecimal(t, "0"), + }, + } + + decimalComparer := cmp.Comparer(func(a, b decimal.Decimal) bool { + return a.Equal(b) + }) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got, err := liquidity.BuildRebalanceAllocations(tt.ranges, tt.amount) + if err != nil { + if !tt.wantErr { + t.Fatalf("BuildRebalanceAllocations() failed: %v", err) + } + + if tt.wantErrIs != nil && !errors.Is(err, tt.wantErrIs) { + t.Fatalf("BuildRebalanceAllocations() error = %v, want %v", err, tt.wantErrIs) + } + + return + } + + if tt.wantErr { + t.Fatalf("BuildRebalanceAllocations() expected error") + } + + if diff := cmp.Diff(tt.want, got, decimalComparer); diff != "" { + t.Fatalf("BuildRebalanceAllocations() mismatch (-want +got):\n%s", diff) + } + + sum := decimal.Zero + for _, allocation := range got { + sum = sum.Add(allocation.Amount) + } + + if !sum.Equal(tt.wantAmount) { + t.Fatalf("BuildRebalanceAllocations() amount sum = %s, want %s", sum, tt.wantAmount) + } + }) + } +} + +func mustDecimal(t *testing.T, value string) decimal.Decimal { + t.Helper() + + dec, err := decimal.NewFromString(value) + if err != nil { + t.Fatalf("NewDecimalFromString(%q) failed: %v", value, err) + } + + return dec +} diff --git a/internal/liquidity/uniswap.go b/internal/liquidity/uniswap.go new file mode 100644 index 0000000..b4219ce --- /dev/null +++ b/internal/liquidity/uniswap.go @@ -0,0 +1,16 @@ +package liquidity + +import "github.com/shopspring/decimal" + +type TickRangeWeight struct { + TickLower int + TickUpper int + Weight decimal.Decimal +} + +type RebalanceAllocation struct { + TickLower int + TickUpper int + Weight decimal.Decimal + Amount decimal.Decimal +}