Skip to content

Commit

Permalink
Cleanup fee.staticCalculator (#3210)
Browse files Browse the repository at this point in the history
  • Loading branch information
StephenButtolph authored Jul 22, 2024
1 parent 78b9f28 commit f4d8a3c
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 65 deletions.
2 changes: 1 addition & 1 deletion vms/platformvm/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ func TestGetBalance(t *testing.T) {

feeCalculator, err := state.PickFeeCalculator(&service.vm.Config, service.vm.state)
require.NoError(err)
createSubnetFee, err := feeCalculator.CalculateFee(&txs.Tx{Unsigned: &txs.CreateSubnetTx{}})
createSubnetFee, err := feeCalculator.CalculateFee(&txs.CreateSubnetTx{})
require.NoError(err)

// Ensure GetStake is correct for each of the genesis validators
Expand Down
14 changes: 7 additions & 7 deletions vms/platformvm/txs/executor/staker_tx_verification.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ func verifyAddValidatorTx(
}

// Verify the flowcheck
fee, err := feeCalculator.CalculateFee(sTx)
fee, err := feeCalculator.CalculateFee(tx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -258,7 +258,7 @@ func verifyAddSubnetValidatorTx(
}

// Verify the flowcheck
fee, err := feeCalculator.CalculateFee(sTx)
fee, err := feeCalculator.CalculateFee(tx)
if err != nil {
return err
}
Expand Down Expand Up @@ -338,7 +338,7 @@ func verifyRemoveSubnetValidatorTx(
}

// Verify the flowcheck
fee, err := feeCalculator.CalculateFee(sTx)
fee, err := feeCalculator.CalculateFee(tx)
if err != nil {
return nil, false, err
}
Expand Down Expand Up @@ -458,7 +458,7 @@ func verifyAddDelegatorTx(
}

// Verify the flowcheck
fee, err := feeCalculator.CalculateFee(sTx)
fee, err := feeCalculator.CalculateFee(tx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -580,7 +580,7 @@ func verifyAddPermissionlessValidatorTx(
copy(outs[len(tx.Outs):], tx.StakeOuts)

// Verify the flowcheck
fee, err := feeCalculator.CalculateFee(sTx)
fee, err := feeCalculator.CalculateFee(tx)
if err != nil {
return err
}
Expand Down Expand Up @@ -727,7 +727,7 @@ func verifyAddPermissionlessDelegatorTx(
}

// Verify the flowcheck
fee, err := feeCalculator.CalculateFee(sTx)
fee, err := feeCalculator.CalculateFee(tx)
if err != nil {
return err
}
Expand Down Expand Up @@ -787,7 +787,7 @@ func verifyTransferSubnetOwnershipTx(
}

// Verify the flowcheck
fee, err := feeCalculator.CalculateFee(sTx)
fee, err := feeCalculator.CalculateFee(tx)
if err != nil {
return err
}
Expand Down
12 changes: 6 additions & 6 deletions vms/platformvm/txs/executor/standard_tx_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func (e *StandardTxExecutor) CreateChainTx(tx *txs.CreateChainTx) error {
}

// Verify the flowcheck
fee, err := e.FeeCalculator.CalculateFee(e.Tx)
fee, err := e.FeeCalculator.CalculateFee(tx)
if err != nil {
return err
}
Expand Down Expand Up @@ -119,7 +119,7 @@ func (e *StandardTxExecutor) CreateSubnetTx(tx *txs.CreateSubnetTx) error {
}

// Verify the flowcheck
fee, err := e.FeeCalculator.CalculateFee(e.Tx)
fee, err := e.FeeCalculator.CalculateFee(tx)
if err != nil {
return err
}
Expand Down Expand Up @@ -203,7 +203,7 @@ func (e *StandardTxExecutor) ImportTx(tx *txs.ImportTx) error {
copy(ins[len(tx.Ins):], tx.ImportedInputs)

// Verify the flowcheck
fee, err := e.FeeCalculator.CalculateFee(e.Tx)
fee, err := e.FeeCalculator.CalculateFee(tx)
if err != nil {
return err
}
Expand Down Expand Up @@ -263,7 +263,7 @@ func (e *StandardTxExecutor) ExportTx(tx *txs.ExportTx) error {
}

// Verify the flowcheck
fee, err := e.FeeCalculator.CalculateFee(e.Tx)
fee, err := e.FeeCalculator.CalculateFee(tx)
if err != nil {
return err
}
Expand Down Expand Up @@ -457,7 +457,7 @@ func (e *StandardTxExecutor) TransformSubnetTx(tx *txs.TransformSubnetTx) error
}

// Verify the flowcheck
fee, err := e.FeeCalculator.CalculateFee(e.Tx)
fee, err := e.FeeCalculator.CalculateFee(tx)
if err != nil {
return err
}
Expand Down Expand Up @@ -588,7 +588,7 @@ func (e *StandardTxExecutor) BaseTx(tx *txs.BaseTx) error {
}

// Verify the flowcheck
fee, err := e.FeeCalculator.CalculateFee(e.Tx)
fee, err := e.FeeCalculator.CalculateFee(tx)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion vms/platformvm/txs/fee/calculator.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ import "github.com/ava-labs/avalanchego/vms/platformvm/txs"

// Calculator is the interfaces that any fee Calculator must implement
type Calculator interface {
CalculateFee(tx *txs.Tx) (uint64, error)
CalculateFee(tx txs.UnsignedTx) (uint64, error)
}
4 changes: 2 additions & 2 deletions vms/platformvm/txs/fee/calculator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,9 @@ func TestTxFees(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
uTx := tt.unsignedTx()
tx := tt.unsignedTx()
fc := NewStaticCalculator(feeTestsDefaultCfg, upgrades, tt.chainTime)
fee, err := fc.CalculateFee(&txs.Tx{Unsigned: uTx})
fee, err := fc.CalculateFee(tx)
require.NoError(t, err)
require.Equal(t, tt.expected, fee)
})
Expand Down
102 changes: 56 additions & 46 deletions vms/platformvm/txs/fee/static_calculator.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (

var (
_ Calculator = (*staticCalculator)(nil)
_ txs.Visitor = (*staticCalculator)(nil)
_ txs.Visitor = (*staticVisitor)(nil)
)

func NewStaticCalculator(
Expand All @@ -22,115 +22,125 @@ func NewStaticCalculator(
chainTime time.Time,
) Calculator {
return &staticCalculator{
upgrades: upgradeTimes,
staticCfg: config,
time: chainTime,
upgrades: upgradeTimes,
config: config,
time: chainTime,
}
}

type staticCalculator struct {
// inputs
staticCfg StaticConfig
upgrades upgrade.Config
time time.Time
config StaticConfig
upgrades upgrade.Config
time time.Time
}

// outputs of visitor execution
fee uint64
func (c *staticCalculator) CalculateFee(tx txs.UnsignedTx) (uint64, error) {
v := staticVisitor{
config: c.config,
upgrades: c.upgrades,
time: c.time,
}
err := tx.Visit(&v)
return v.fee, err
}

func (c *staticCalculator) CalculateFee(tx *txs.Tx) (uint64, error) {
c.fee = 0 // zero fee among different calculateFee invocations (unlike gas which gets cumulated)
err := tx.Unsigned.Visit(c)
return c.fee, err
type staticVisitor struct {
// inputs
config StaticConfig
upgrades upgrade.Config
time time.Time

// outputs
fee uint64
}

func (c *staticCalculator) AddValidatorTx(*txs.AddValidatorTx) error {
c.fee = c.staticCfg.AddPrimaryNetworkValidatorFee
func (c *staticVisitor) AddValidatorTx(*txs.AddValidatorTx) error {
c.fee = c.config.AddPrimaryNetworkValidatorFee
return nil
}

func (c *staticCalculator) AddSubnetValidatorTx(*txs.AddSubnetValidatorTx) error {
c.fee = c.staticCfg.AddSubnetValidatorFee
func (c *staticVisitor) AddSubnetValidatorTx(*txs.AddSubnetValidatorTx) error {
c.fee = c.config.AddSubnetValidatorFee
return nil
}

func (c *staticCalculator) AddDelegatorTx(*txs.AddDelegatorTx) error {
c.fee = c.staticCfg.AddPrimaryNetworkDelegatorFee
func (c *staticVisitor) AddDelegatorTx(*txs.AddDelegatorTx) error {
c.fee = c.config.AddPrimaryNetworkDelegatorFee
return nil
}

func (c *staticCalculator) CreateChainTx(*txs.CreateChainTx) error {
func (c *staticVisitor) CreateChainTx(*txs.CreateChainTx) error {
if c.upgrades.IsApricotPhase3Activated(c.time) {
c.fee = c.staticCfg.CreateBlockchainTxFee
c.fee = c.config.CreateBlockchainTxFee
} else {
c.fee = c.staticCfg.CreateAssetTxFee
c.fee = c.config.CreateAssetTxFee
}
return nil
}

func (c *staticCalculator) CreateSubnetTx(*txs.CreateSubnetTx) error {
func (c *staticVisitor) CreateSubnetTx(*txs.CreateSubnetTx) error {
if c.upgrades.IsApricotPhase3Activated(c.time) {
c.fee = c.staticCfg.CreateSubnetTxFee
c.fee = c.config.CreateSubnetTxFee
} else {
c.fee = c.staticCfg.CreateAssetTxFee
c.fee = c.config.CreateAssetTxFee
}
return nil
}

func (c *staticCalculator) AdvanceTimeTx(*txs.AdvanceTimeTx) error {
func (c *staticVisitor) AdvanceTimeTx(*txs.AdvanceTimeTx) error {
c.fee = 0 // no fees
return nil
}

func (c *staticCalculator) RewardValidatorTx(*txs.RewardValidatorTx) error {
func (c *staticVisitor) RewardValidatorTx(*txs.RewardValidatorTx) error {
c.fee = 0 // no fees
return nil
}

func (c *staticCalculator) RemoveSubnetValidatorTx(*txs.RemoveSubnetValidatorTx) error {
c.fee = c.staticCfg.TxFee
func (c *staticVisitor) RemoveSubnetValidatorTx(*txs.RemoveSubnetValidatorTx) error {
c.fee = c.config.TxFee
return nil
}

func (c *staticCalculator) TransformSubnetTx(*txs.TransformSubnetTx) error {
c.fee = c.staticCfg.TransformSubnetTxFee
func (c *staticVisitor) TransformSubnetTx(*txs.TransformSubnetTx) error {
c.fee = c.config.TransformSubnetTxFee
return nil
}

func (c *staticCalculator) TransferSubnetOwnershipTx(*txs.TransferSubnetOwnershipTx) error {
c.fee = c.staticCfg.TxFee
func (c *staticVisitor) TransferSubnetOwnershipTx(*txs.TransferSubnetOwnershipTx) error {
c.fee = c.config.TxFee
return nil
}

func (c *staticCalculator) AddPermissionlessValidatorTx(tx *txs.AddPermissionlessValidatorTx) error {
func (c *staticVisitor) AddPermissionlessValidatorTx(tx *txs.AddPermissionlessValidatorTx) error {
if tx.Subnet != constants.PrimaryNetworkID {
c.fee = c.staticCfg.AddSubnetValidatorFee
c.fee = c.config.AddSubnetValidatorFee
} else {
c.fee = c.staticCfg.AddPrimaryNetworkValidatorFee
c.fee = c.config.AddPrimaryNetworkValidatorFee
}
return nil
}

func (c *staticCalculator) AddPermissionlessDelegatorTx(tx *txs.AddPermissionlessDelegatorTx) error {
func (c *staticVisitor) AddPermissionlessDelegatorTx(tx *txs.AddPermissionlessDelegatorTx) error {
if tx.Subnet != constants.PrimaryNetworkID {
c.fee = c.staticCfg.AddSubnetDelegatorFee
c.fee = c.config.AddSubnetDelegatorFee
} else {
c.fee = c.staticCfg.AddPrimaryNetworkDelegatorFee
c.fee = c.config.AddPrimaryNetworkDelegatorFee
}
return nil
}

func (c *staticCalculator) BaseTx(*txs.BaseTx) error {
c.fee = c.staticCfg.TxFee
func (c *staticVisitor) BaseTx(*txs.BaseTx) error {
c.fee = c.config.TxFee
return nil
}

func (c *staticCalculator) ImportTx(*txs.ImportTx) error {
c.fee = c.staticCfg.TxFee
func (c *staticVisitor) ImportTx(*txs.ImportTx) error {
c.fee = c.config.TxFee
return nil
}

func (c *staticCalculator) ExportTx(*txs.ExportTx) error {
c.fee = c.staticCfg.TxFee
func (c *staticVisitor) ExportTx(*txs.ExportTx) error {
c.fee = c.config.TxFee
return nil
}
4 changes: 2 additions & 2 deletions vms/platformvm/txs/txstest/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ func newContext(
) *builder.Context {
var (
feeCalculator = fee.NewStaticCalculator(cfg.StaticFeeConfig, cfg.UpgradeConfig, timestamp)
createSubnetFee, _ = feeCalculator.CalculateFee(&txs.Tx{Unsigned: &txs.CreateSubnetTx{}})
createChainFee, _ = feeCalculator.CalculateFee(&txs.Tx{Unsigned: &txs.CreateChainTx{}})
createSubnetFee, _ = feeCalculator.CalculateFee(&txs.CreateSubnetTx{})
createChainFee, _ = feeCalculator.CalculateFee(&txs.CreateChainTx{})
)

return &builder.Context{
Expand Down

0 comments on commit f4d8a3c

Please sign in to comment.