From d0c2e3f97efd1db71dc859f1839c4e47e04209da Mon Sep 17 00:00:00 2001 From: --global Date: Tue, 3 Feb 2026 17:06:46 +0800 Subject: [PATCH 1/2] feat: Implement rebalance allocation logic and add related error handling - Introduced `BuildRebalanceAllocations` function to calculate allocations based on tick ranges, total amount, and slippage tolerance. - Added new error types for better validation of input parameters, including invalid tick ranges, weights, total amounts, and slippage. - Created unit tests for the rebalance allocation logic to ensure correctness and handle various edge cases. - Updated `go.mod` and `go.sum` to include the `shopspring/decimal` dependency for precise decimal arithmetic. --- go.mod | 1 + go.sum | 2 + internal/liquidity/errors.go | 8 +- internal/liquidity/rebalance.go | 62 +++++++++ internal/liquidity/rebalance_test.go | 195 +++++++++++++++++++++++++++ internal/liquidity/uniswap.go | 17 +++ 6 files changed, 284 insertions(+), 1 deletion(-) create mode 100644 internal/liquidity/rebalance.go create mode 100644 internal/liquidity/rebalance_test.go create mode 100644 internal/liquidity/uniswap.go diff --git a/go.mod b/go.mod index 2708e71..43c2589 100644 --- a/go.mod +++ b/go.mod @@ -229,6 +229,7 @@ require ( github.com/schollz/progressbar/v3 v3.18.0 // indirect github.com/securego/gosec/v2 v2.22.11 // indirect github.com/shirou/gopsutil v3.21.4-0.20210419000835-c7a38de76ee5+incompatible // indirect + github.com/shopspring/decimal v1.4.0 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/sivchari/containedctx v1.0.3 // indirect github.com/sonatard/noctx v0.4.0 // indirect diff --git a/go.sum b/go.sum index 0a98b2f..48b4b50 100644 --- a/go.sum +++ b/go.sum @@ -600,6 +600,8 @@ github.com/sergi/go-diff v1.2.0 h1:XU+rvMAioB0UC3q1MFrIQy4Vo5/4VsRDQQXHsEya6xQ= github.com/sergi/go-diff v1.2.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= github.com/shirou/gopsutil v3.21.4-0.20210419000835-c7a38de76ee5+incompatible h1:Bn1aCHHRnjv4Bl16T8rcaFjYSrGrIZvpiGO6P3Q4GpU= github.com/shirou/gopsutil v3.21.4-0.20210419000835-c7a38de76ee5+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= +github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= +github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= github.com/shurcooL/go v0.0.0-20180423040247-9e1955d9fb6e/go.mod h1:TDJrrUr11Vxrven61rcy3hJMUqaf/CLWYhHNPmT14Lk= github.com/shurcooL/go-goon v0.0.0-20170922171312-37c2f522c041/go.mod h1:N5mDOmsrJOB+vfqUK+7DmDyjhSLIIBnXo9lvZJj3MWQ= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= diff --git a/internal/liquidity/errors.go b/internal/liquidity/errors.go index 07f6b1c..a8a0029 100644 --- a/internal/liquidity/errors.go +++ b/internal/liquidity/errors.go @@ -10,8 +10,14 @@ var ( ErrInvalidBinSize = errors.New("bin size must be positive") // ErrInvalidTickRange is returned when tick range is invalid. - ErrInvalidTickRange = errors.New("tick range must be positive") + ErrInvalidTickRange = errors.New("invalid tick range") // ErrContractCall is returned when contract call fails. ErrContractCall = errors.New("contract call failed") + + ErrNoTickRanges = errors.New("no tick ranges") + ErrInvalidWeight = errors.New("invalid weight") + ErrInvalidTotalAmount = errors.New("invalid total amount") + ErrInvalidSlippage = errors.New("invalid slippage tolerance") + ErrZeroTotalWeight = errors.New("zero total weight") ) diff --git a/internal/liquidity/rebalance.go b/internal/liquidity/rebalance.go new file mode 100644 index 0000000..15d88ef --- /dev/null +++ b/internal/liquidity/rebalance.go @@ -0,0 +1,62 @@ +package liquidity + +import "github.com/shopspring/decimal" + +func BuildRebalanceAllocations( + ranges []TickRangeWeight, + totalAmount decimal.Decimal, + slippageTolerance decimal.Decimal, +) ([]RebalanceAllocation, error) { + if len(ranges) == 0 { + return nil, ErrNoTickRanges + } + + if totalAmount.IsNegative() { + return nil, ErrInvalidTotalAmount + } + + if slippageTolerance.IsNegative() || slippageTolerance.GreaterThan(decimal.NewFromInt(1)) { + return nil, ErrInvalidSlippage + } + + sumWeight := decimal.NewFromInt(0) + for _, r := range ranges { + if r.TickLower >= r.TickUpper { + return nil, ErrInvalidTickRange + } + + if r.Weight.LessThan(decimal.NewFromInt(0)) { + return nil, ErrInvalidWeight + } + + sumWeight = sumWeight.Add(r.Weight) + } + + if sumWeight.IsZero() { + return nil, ErrZeroTotalWeight + } + + allocations := make([]RebalanceAllocation, 0, len(ranges)) + remaining := totalAmount + minFactor := decimal.NewFromInt(1).Sub(slippageTolerance) + + for i, r := range ranges { + var amount decimal.Decimal + if i == len(ranges)-1 { + amount = remaining + } else { + amount = totalAmount.Mul(r.Weight).DivRound(sumWeight, 18) + remaining = remaining.Sub(amount) + } + + allocations = append(allocations, RebalanceAllocation{ + TickLower: r.TickLower, + TickUpper: r.TickUpper, + Weight: r.Weight, + Amount: amount, + AmountMin: amount.Mul(minFactor).Round(18), + }) + } + + return allocations, nil +} diff --git a/internal/liquidity/rebalance_test.go b/internal/liquidity/rebalance_test.go new file mode 100644 index 0000000..254a601 --- /dev/null +++ b/internal/liquidity/rebalance_test.go @@ -0,0 +1,195 @@ +package liquidity_test + +import ( + "errors" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/shopspring/decimal" + + "remora/internal/liquidity" +) + +func TestBuildRebalanceAllocations(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + ranges []liquidity.TickRangeWeight + amount decimal.Decimal + slippage decimal.Decimal + want []liquidity.RebalanceAllocation + wantErr bool + wantErrIs error + wantAmount decimal.Decimal + }{ + { + name: "success - weighted allocation with slippage", + ranges: []liquidity.TickRangeWeight{ + {TickLower: -100, TickUpper: 0, Weight: mustDecimal(t, "1")}, + {TickLower: 0, TickUpper: 100, Weight: mustDecimal(t, "3")}, + }, + amount: mustDecimal(t, "100"), + slippage: mustDecimal(t, "0.01"), + want: []liquidity.RebalanceAllocation{ + { + TickLower: -100, + TickUpper: 0, + Weight: mustDecimal(t, "1"), + Amount: mustDecimal(t, "25"), + AmountMin: mustDecimal(t, "24.75"), + }, + { + TickLower: 0, + TickUpper: 100, + Weight: mustDecimal(t, "3"), + Amount: mustDecimal(t, "75"), + AmountMin: mustDecimal(t, "74.25"), + }, + }, + wantAmount: mustDecimal(t, "100"), + }, + { + name: "success - zero weight keeps entry with zero amounts", + ranges: []liquidity.TickRangeWeight{ + {TickLower: -200, TickUpper: -100, Weight: mustDecimal(t, "0")}, + {TickLower: -100, TickUpper: 100, Weight: mustDecimal(t, "2")}, + }, + amount: mustDecimal(t, "10"), + slippage: mustDecimal(t, "0.05"), + want: []liquidity.RebalanceAllocation{ + { + TickLower: -200, + TickUpper: -100, + Weight: mustDecimal(t, "0"), + Amount: mustDecimal(t, "0"), + AmountMin: mustDecimal(t, "0"), + }, + { + TickLower: -100, + TickUpper: 100, + Weight: mustDecimal(t, "2"), + Amount: mustDecimal(t, "10"), + AmountMin: mustDecimal(t, "9.5"), + }, + }, + wantAmount: mustDecimal(t, "10"), + }, + { + name: "error - no ranges", + ranges: nil, + amount: mustDecimal(t, "1"), + slippage: mustDecimal(t, "0"), + wantErr: true, + wantErrIs: liquidity.ErrNoTickRanges, + wantAmount: mustDecimal(t, "0"), + }, + { + name: "error - invalid tick range", + ranges: []liquidity.TickRangeWeight{ + {TickLower: 100, TickUpper: 100, Weight: mustDecimal(t, "1")}, + }, + amount: mustDecimal(t, "1"), + slippage: mustDecimal(t, "0"), + wantErr: true, + wantErrIs: liquidity.ErrInvalidTickRange, + wantAmount: mustDecimal(t, "0"), + }, + { + name: "error - negative weight", + ranges: []liquidity.TickRangeWeight{ + {TickLower: 0, TickUpper: 10, Weight: mustDecimal(t, "-1")}, + }, + amount: mustDecimal(t, "1"), + slippage: mustDecimal(t, "0"), + wantErr: true, + wantErrIs: liquidity.ErrInvalidWeight, + wantAmount: mustDecimal(t, "0"), + }, + { + name: "error - negative total amount", + ranges: []liquidity.TickRangeWeight{ + {TickLower: 0, TickUpper: 10, Weight: mustDecimal(t, "1")}, + }, + amount: mustDecimal(t, "-1"), + slippage: mustDecimal(t, "0"), + wantErr: true, + wantErrIs: liquidity.ErrInvalidTotalAmount, + wantAmount: mustDecimal(t, "0"), + }, + { + name: "error - slippage > 1", + ranges: []liquidity.TickRangeWeight{ + {TickLower: 0, TickUpper: 10, Weight: mustDecimal(t, "1")}, + }, + amount: mustDecimal(t, "1"), + slippage: mustDecimal(t, "1.01"), + wantErr: true, + wantErrIs: liquidity.ErrInvalidSlippage, + wantAmount: mustDecimal(t, "0"), + }, + { + name: "error - total weight zero", + ranges: []liquidity.TickRangeWeight{ + {TickLower: 0, TickUpper: 10, Weight: mustDecimal(t, "0")}, + {TickLower: 10, TickUpper: 20, Weight: mustDecimal(t, "0")}, + }, + amount: mustDecimal(t, "1"), + slippage: mustDecimal(t, "0"), + wantErr: true, + wantErrIs: liquidity.ErrZeroTotalWeight, + wantAmount: mustDecimal(t, "0"), + }, + } + + decimalComparer := cmp.Comparer(func(a, b decimal.Decimal) bool { + return a.Equal(b) + }) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got, err := liquidity.BuildRebalanceAllocations(tt.ranges, tt.amount, tt.slippage) + if err != nil { + if !tt.wantErr { + t.Fatalf("BuildRebalanceAllocations() failed: %v", err) + } + + if tt.wantErrIs != nil && !errors.Is(err, tt.wantErrIs) { + t.Fatalf("BuildRebalanceAllocations() error = %v, want %v", err, tt.wantErrIs) + } + + return + } + + if tt.wantErr { + t.Fatalf("BuildRebalanceAllocations() expected error") + } + + if diff := cmp.Diff(tt.want, got, decimalComparer); diff != "" { + t.Fatalf("BuildRebalanceAllocations() mismatch (-want +got):\n%s", diff) + } + + sum := decimal.NewFromInt(0) + for _, allocation := range got { + sum = sum.Add(allocation.Amount) + } + + if !sum.Equal(tt.wantAmount) { + t.Fatalf("BuildRebalanceAllocations() amount sum = %s, want %s", sum, tt.wantAmount) + } + }) + } +} + +func mustDecimal(t *testing.T, value string) decimal.Decimal { + t.Helper() + + dec, err := decimal.NewFromString(value) + if err != nil { + t.Fatalf("NewDecimalFromString(%q) failed: %v", value, err) + } + + return dec +} diff --git a/internal/liquidity/uniswap.go b/internal/liquidity/uniswap.go new file mode 100644 index 0000000..205413b --- /dev/null +++ b/internal/liquidity/uniswap.go @@ -0,0 +1,17 @@ +package liquidity + +import "github.com/shopspring/decimal" + +type TickRangeWeight struct { + TickLower int + TickUpper int + Weight decimal.Decimal +} + +type RebalanceAllocation struct { + TickLower int + TickUpper int + Weight decimal.Decimal + Amount decimal.Decimal + AmountMin decimal.Decimal +} From 3b726d3026fc559eacd580f6e3545087b539b0fd Mon Sep 17 00:00:00 2001 From: --global Date: Tue, 3 Feb 2026 17:29:18 +0800 Subject: [PATCH 2/2] refactor: Remove slippage tolerance from rebalance allocation logic - Eliminated slippage tolerance parameter from `BuildRebalanceAllocations` function and related tests to simplify allocation calculations. - Updated error handling to focus on tick ranges, weights, and total amounts without slippage considerations. - Adjusted unit tests to reflect the removal of slippage, ensuring correctness in allocation logic. --- internal/liquidity/errors.go | 1 - internal/liquidity/rebalance.go | 19 +++++----------- internal/liquidity/rebalance_test.go | 33 +++++----------------------- internal/liquidity/uniswap.go | 1 - 4 files changed, 11 insertions(+), 43 deletions(-) diff --git a/internal/liquidity/errors.go b/internal/liquidity/errors.go index a8a0029..11baba7 100644 --- a/internal/liquidity/errors.go +++ b/internal/liquidity/errors.go @@ -18,6 +18,5 @@ var ( ErrNoTickRanges = errors.New("no tick ranges") ErrInvalidWeight = errors.New("invalid weight") ErrInvalidTotalAmount = errors.New("invalid total amount") - ErrInvalidSlippage = errors.New("invalid slippage tolerance") ErrZeroTotalWeight = errors.New("zero total weight") ) diff --git a/internal/liquidity/rebalance.go b/internal/liquidity/rebalance.go index 15d88ef..b61fa91 100644 --- a/internal/liquidity/rebalance.go +++ b/internal/liquidity/rebalance.go @@ -5,7 +5,6 @@ import "github.com/shopspring/decimal" func BuildRebalanceAllocations( ranges []TickRangeWeight, totalAmount decimal.Decimal, - slippageTolerance decimal.Decimal, ) ([]RebalanceAllocation, error) { if len(ranges) == 0 { return nil, ErrNoTickRanges @@ -15,17 +14,13 @@ func BuildRebalanceAllocations( return nil, ErrInvalidTotalAmount } - if slippageTolerance.IsNegative() || slippageTolerance.GreaterThan(decimal.NewFromInt(1)) { - return nil, ErrInvalidSlippage - } - - sumWeight := decimal.NewFromInt(0) + sumWeight := decimal.Zero for _, r := range ranges { if r.TickLower >= r.TickUpper { return nil, ErrInvalidTickRange } - if r.Weight.LessThan(decimal.NewFromInt(0)) { + if r.Weight.IsNegative() { return nil, ErrInvalidWeight } @@ -36,26 +31,24 @@ func BuildRebalanceAllocations( return nil, ErrZeroTotalWeight } - allocations := make([]RebalanceAllocation, 0, len(ranges)) + allocations := make([]RebalanceAllocation, len(ranges)) remaining := totalAmount - minFactor := decimal.NewFromInt(1).Sub(slippageTolerance) for i, r := range ranges { var amount decimal.Decimal if i == len(ranges)-1 { amount = remaining } else { - amount = totalAmount.Mul(r.Weight).DivRound(sumWeight, 18) + amount = totalAmount.Mul(r.Weight).Div(sumWeight) remaining = remaining.Sub(amount) } - allocations = append(allocations, RebalanceAllocation{ + allocations[i] = RebalanceAllocation{ TickLower: r.TickLower, TickUpper: r.TickUpper, Weight: r.Weight, Amount: amount, - AmountMin: amount.Mul(minFactor).Round(18), - }) + } } return allocations, nil diff --git a/internal/liquidity/rebalance_test.go b/internal/liquidity/rebalance_test.go index 254a601..94d41ff 100644 --- a/internal/liquidity/rebalance_test.go +++ b/internal/liquidity/rebalance_test.go @@ -17,34 +17,30 @@ func TestBuildRebalanceAllocations(t *testing.T) { name string ranges []liquidity.TickRangeWeight amount decimal.Decimal - slippage decimal.Decimal want []liquidity.RebalanceAllocation wantErr bool wantErrIs error wantAmount decimal.Decimal }{ { - name: "success - weighted allocation with slippage", + name: "success - weighted allocation", ranges: []liquidity.TickRangeWeight{ {TickLower: -100, TickUpper: 0, Weight: mustDecimal(t, "1")}, {TickLower: 0, TickUpper: 100, Weight: mustDecimal(t, "3")}, }, - amount: mustDecimal(t, "100"), - slippage: mustDecimal(t, "0.01"), + amount: mustDecimal(t, "100"), want: []liquidity.RebalanceAllocation{ { TickLower: -100, TickUpper: 0, Weight: mustDecimal(t, "1"), Amount: mustDecimal(t, "25"), - AmountMin: mustDecimal(t, "24.75"), }, { TickLower: 0, TickUpper: 100, Weight: mustDecimal(t, "3"), Amount: mustDecimal(t, "75"), - AmountMin: mustDecimal(t, "74.25"), }, }, wantAmount: mustDecimal(t, "100"), @@ -55,22 +51,19 @@ func TestBuildRebalanceAllocations(t *testing.T) { {TickLower: -200, TickUpper: -100, Weight: mustDecimal(t, "0")}, {TickLower: -100, TickUpper: 100, Weight: mustDecimal(t, "2")}, }, - amount: mustDecimal(t, "10"), - slippage: mustDecimal(t, "0.05"), + amount: mustDecimal(t, "10"), want: []liquidity.RebalanceAllocation{ { TickLower: -200, TickUpper: -100, Weight: mustDecimal(t, "0"), Amount: mustDecimal(t, "0"), - AmountMin: mustDecimal(t, "0"), }, { TickLower: -100, TickUpper: 100, Weight: mustDecimal(t, "2"), Amount: mustDecimal(t, "10"), - AmountMin: mustDecimal(t, "9.5"), }, }, wantAmount: mustDecimal(t, "10"), @@ -79,7 +72,6 @@ func TestBuildRebalanceAllocations(t *testing.T) { name: "error - no ranges", ranges: nil, amount: mustDecimal(t, "1"), - slippage: mustDecimal(t, "0"), wantErr: true, wantErrIs: liquidity.ErrNoTickRanges, wantAmount: mustDecimal(t, "0"), @@ -90,7 +82,6 @@ func TestBuildRebalanceAllocations(t *testing.T) { {TickLower: 100, TickUpper: 100, Weight: mustDecimal(t, "1")}, }, amount: mustDecimal(t, "1"), - slippage: mustDecimal(t, "0"), wantErr: true, wantErrIs: liquidity.ErrInvalidTickRange, wantAmount: mustDecimal(t, "0"), @@ -101,7 +92,6 @@ func TestBuildRebalanceAllocations(t *testing.T) { {TickLower: 0, TickUpper: 10, Weight: mustDecimal(t, "-1")}, }, amount: mustDecimal(t, "1"), - slippage: mustDecimal(t, "0"), wantErr: true, wantErrIs: liquidity.ErrInvalidWeight, wantAmount: mustDecimal(t, "0"), @@ -112,22 +102,10 @@ func TestBuildRebalanceAllocations(t *testing.T) { {TickLower: 0, TickUpper: 10, Weight: mustDecimal(t, "1")}, }, amount: mustDecimal(t, "-1"), - slippage: mustDecimal(t, "0"), wantErr: true, wantErrIs: liquidity.ErrInvalidTotalAmount, wantAmount: mustDecimal(t, "0"), }, - { - name: "error - slippage > 1", - ranges: []liquidity.TickRangeWeight{ - {TickLower: 0, TickUpper: 10, Weight: mustDecimal(t, "1")}, - }, - amount: mustDecimal(t, "1"), - slippage: mustDecimal(t, "1.01"), - wantErr: true, - wantErrIs: liquidity.ErrInvalidSlippage, - wantAmount: mustDecimal(t, "0"), - }, { name: "error - total weight zero", ranges: []liquidity.TickRangeWeight{ @@ -135,7 +113,6 @@ func TestBuildRebalanceAllocations(t *testing.T) { {TickLower: 10, TickUpper: 20, Weight: mustDecimal(t, "0")}, }, amount: mustDecimal(t, "1"), - slippage: mustDecimal(t, "0"), wantErr: true, wantErrIs: liquidity.ErrZeroTotalWeight, wantAmount: mustDecimal(t, "0"), @@ -150,7 +127,7 @@ func TestBuildRebalanceAllocations(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - got, err := liquidity.BuildRebalanceAllocations(tt.ranges, tt.amount, tt.slippage) + got, err := liquidity.BuildRebalanceAllocations(tt.ranges, tt.amount) if err != nil { if !tt.wantErr { t.Fatalf("BuildRebalanceAllocations() failed: %v", err) @@ -171,7 +148,7 @@ func TestBuildRebalanceAllocations(t *testing.T) { t.Fatalf("BuildRebalanceAllocations() mismatch (-want +got):\n%s", diff) } - sum := decimal.NewFromInt(0) + sum := decimal.Zero for _, allocation := range got { sum = sum.Add(allocation.Amount) } diff --git a/internal/liquidity/uniswap.go b/internal/liquidity/uniswap.go index 205413b..b4219ce 100644 --- a/internal/liquidity/uniswap.go +++ b/internal/liquidity/uniswap.go @@ -13,5 +13,4 @@ type RebalanceAllocation struct { TickUpper int Weight decimal.Decimal Amount decimal.Decimal - AmountMin decimal.Decimal }