diff --git a/.github/workflows/develop.yml b/.github/workflows/develop.yml index 3177f9877..1eb64b326 100644 --- a/.github/workflows/develop.yml +++ b/.github/workflows/develop.yml @@ -3,6 +3,9 @@ on: push: branches: - "develop" + pull_request: + branches: + - "feature/*" jobs: test: @@ -71,6 +74,7 @@ jobs: action: persist push-docker-build: + if: github.event_name == 'push' && startsWith(github.ref, 'refs/heads/develop') runs-on: ubuntu-latest needs: test steps: diff --git a/README.md b/README.md index 2ecb6eb8c..a4180b3b0 100644 --- a/README.md +++ b/README.md @@ -124,7 +124,7 @@ docker exec -it razor-go razor setConfig --provider --alternatePr Example: ``` -$ ./razor setConfig --provider https://mainnet.skalenodes.com/v1/turbulent-unique-scheat --alternateProvider https://ce2m-skale.chainode.tech:10200/ --gasmultiplier 1 --buffer 20 --wait 30 --gasprice 0 --logLevel debug --gasLimit 2 --rpcTimeout 10 --httpTimeout 10 --logFileMaxSize 200 --logFileMaxBackups 52 --logFileMaxAge 365 +$ ./razor setConfig --provider https://mainnet.skalenodes.com/v1/turbulent-unique-scheat --alternateProvider https://ce2m-skale.chainode.tech:10200/ --gasmultiplier 1 --buffer 20 --wait 30 --gasprice 0 --logLevel debug --gasLimit 2 --rpcTimeout 10 --httpTimeout 10 --logFileMaxSize 200 --logFileMaxBackups 10 --logFileMaxAge 60 ``` Besides, setting these parameters in the config, you can use different values for these parameters in various commands. Just add the same flag to any command you want to use and the new config changes will appear for that command. diff --git a/accounts/accountUtils.go b/accounts/accountUtils.go index 11f8a5f10..e0d7e87df 100644 --- a/accounts/accountUtils.go +++ b/accounts/accountUtils.go @@ -2,56 +2,26 @@ package accounts import ( - "crypto/ecdsa" - "github.com/ethereum/go-ethereum/accounts" "github.com/ethereum/go-ethereum/accounts/keystore" - "github.com/ethereum/go-ethereum/crypto" - "os" "razor/core/types" ) -//go:generate mockery --name AccountInterface --output ./mocks/ --case=underscore - -var AccountUtilsInterface AccountInterface - -type AccountInterface interface { - CreateAccount(path string, password string) accounts.Account - GetPrivateKeyFromKeystore(keystorePath string, password string) (*ecdsa.PrivateKey, error) - GetPrivateKey(address string, password string, keystorePath string) (*ecdsa.PrivateKey, error) - SignData(hash []byte, account types.Account, defaultPath string) ([]byte, error) - Accounts(path string) []accounts.Account - NewAccount(path string, passphrase string) (accounts.Account, error) - DecryptKey(jsonBytes []byte, password string) (*keystore.Key, error) - Sign(digestHash []byte, prv *ecdsa.PrivateKey) ([]byte, error) - ReadFile(filename string) ([]byte, error) -} - -type AccountUtils struct{} - -//This function returns all the accounts in form of array -func (accountUtils AccountUtils) Accounts(path string) []accounts.Account { - ks := keystore.NewKeyStore(path, keystore.StandardScryptN, keystore.StandardScryptP) - return ks.Accounts() -} - -//This function takes path and pass phrase as input and returns the new account -func (accountUtils AccountUtils) NewAccount(path string, passphrase string) (accounts.Account, error) { - ks := keystore.NewKeyStore(path, keystore.StandardScryptN, keystore.StandardScryptP) - accounts.NewManager(&accounts.Config{InsecureUnlockAllowed: false}, ks) - return ks.NewAccount(passphrase) -} - -//This function takes json bytes array and password as input and returns the decrypted key -func (accountUtils AccountUtils) DecryptKey(jsonBytes []byte, password string) (*keystore.Key, error) { - return keystore.DecryptKey(jsonBytes, password) +type AccountManager struct { + Keystore *keystore.KeyStore } -//This function takes hash in form of byte array and private key as input and returns signature as byte array -func (accountUtils AccountUtils) Sign(digestHash []byte, prv *ecdsa.PrivateKey) (sig []byte, err error) { - return crypto.Sign(digestHash, prv) +func NewAccountManager(keystorePath string) *AccountManager { + ks := keystore.NewKeyStore(keystorePath, keystore.StandardScryptN, keystore.StandardScryptP) + return &AccountManager{ + Keystore: ks, + } } -//This function takes name of the file as input and returns the file data as byte array -func (accountUtils AccountUtils) ReadFile(filename string) ([]byte, error) { - return os.ReadFile(filename) +// InitAccountStruct initializes an Account struct with provided details. +func InitAccountStruct(address, password string, accountManager types.AccountManagerInterface) types.Account { + return types.Account{ + Address: address, + Password: password, + AccountManager: accountManager, + } } diff --git a/accounts/accounts.go b/accounts/accounts.go index b907dfef9..4284ed6a3 100644 --- a/accounts/accounts.go +++ b/accounts/accounts.go @@ -5,7 +5,9 @@ import ( "crypto/ecdsa" "errors" "github.com/ethereum/go-ethereum/accounts" - "razor/core/types" + "github.com/ethereum/go-ethereum/accounts/keystore" + "github.com/ethereum/go-ethereum/crypto" + "os" "razor/logger" "razor/path" "strings" @@ -14,28 +16,46 @@ import ( var log = logger.NewLogger() //This function takes path and password as input and returns new account -func (AccountUtils) CreateAccount(keystorePath string, password string) accounts.Account { +func (am *AccountManager) CreateAccount(keystorePath string, password string) accounts.Account { if _, err := path.OSUtilsInterface.Stat(keystorePath); path.OSUtilsInterface.IsNotExist(err) { mkdirErr := path.OSUtilsInterface.Mkdir(keystorePath, 0700) if mkdirErr != nil { log.Fatal("Error in creating directory: ", mkdirErr) } } - newAcc, err := AccountUtilsInterface.NewAccount(keystorePath, password) + newAcc, err := am.NewAccount(password) if err != nil { log.Fatal("Error in creating account: ", err) } return newAcc } +//This function takes path and pass phrase as input and returns the new account +func (am *AccountManager) NewAccount(passphrase string) (accounts.Account, error) { + ks := am.Keystore + accounts.NewManager(&accounts.Config{InsecureUnlockAllowed: false}, ks) + return ks.NewAccount(passphrase) +} + +//This function takes address of account, password and keystore path as input and returns private key of account +func (am *AccountManager) GetPrivateKey(address string, password string) (*ecdsa.PrivateKey, error) { + allAccounts := am.Keystore.Accounts() + for _, account := range allAccounts { + if strings.EqualFold(account.Address.Hex(), address) { + return getPrivateKeyFromKeystore(account.URL.Path, password) + } + } + return nil, errors.New("no keystore file found") +} + //This function takes and path of keystore and password as input and returns private key of account -func (AccountUtils) GetPrivateKeyFromKeystore(keystorePath string, password string) (*ecdsa.PrivateKey, error) { - jsonBytes, err := AccountUtilsInterface.ReadFile(keystorePath) +func getPrivateKeyFromKeystore(keystoreFilePath string, password string) (*ecdsa.PrivateKey, error) { + jsonBytes, err := os.ReadFile(keystoreFilePath) if err != nil { log.Error("Error in reading keystore: ", err) return nil, err } - key, err := AccountUtilsInterface.DecryptKey(jsonBytes, password) + key, err := keystore.DecryptKey(jsonBytes, password) if err != nil { log.Error("Error in fetching private key: ", err) return nil, err @@ -43,22 +63,11 @@ func (AccountUtils) GetPrivateKeyFromKeystore(keystorePath string, password stri return key.PrivateKey, nil } -//This function takes address of account, password and keystore path as input and returns private key of account -func (AccountUtils) GetPrivateKey(address string, password string, keystorePath string) (*ecdsa.PrivateKey, error) { - allAccounts := AccountUtilsInterface.Accounts(keystorePath) - for _, account := range allAccounts { - if strings.EqualFold(account.Address.Hex(), address) { - return AccountUtilsInterface.GetPrivateKeyFromKeystore(account.URL.Path, password) - } - } - return nil, errors.New("no keystore file found") -} - //This function takes hash, account and path as input and returns the signed data as array of byte -func (AccountUtils) SignData(hash []byte, account types.Account, defaultPath string) ([]byte, error) { - privateKey, err := AccountUtilsInterface.GetPrivateKey(account.Address, account.Password, defaultPath) +func (am *AccountManager) SignData(hash []byte, address string, password string) ([]byte, error) { + privateKey, err := am.GetPrivateKey(address, password) if err != nil { return nil, err } - return AccountUtilsInterface.Sign(hash, privateKey) + return crypto.Sign(hash, privateKey) } diff --git a/accounts/accounts_test.go b/accounts/accounts_test.go index aa964bf47..986a78c5c 100644 --- a/accounts/accounts_test.go +++ b/accounts/accounts_test.go @@ -2,314 +2,193 @@ package accounts import ( "crypto/ecdsa" - "errors" - "github.com/ethereum/go-ethereum/accounts" - "github.com/ethereum/go-ethereum/accounts/keystore" - "github.com/ethereum/go-ethereum/common" - "github.com/magiconair/properties/assert" - "github.com/stretchr/testify/mock" - "io/fs" - "razor/accounts/mocks" - "razor/core/types" - "razor/path" - mocks1 "razor/path/mocks" + "encoding/hex" "reflect" "testing" ) -func TestCreateAccount(t *testing.T) { - var keystorePath string - var password string - var fileInfo fs.FileInfo - - type args struct { - account accounts.Account - accountErr error - statErr error - isNotExist bool - mkdirErr error - } - tests := []struct { - name string - args args - want accounts.Account - expectedFatal bool - }{ - { - name: "Test 1: When NewAccounts executes successfully", - args: args{ - account: accounts.Account{Address: common.HexToAddress("0x000000000000000000000000000000000000dea1"), - URL: accounts.URL{Scheme: "TestKeyScheme", Path: "test/key/path"}, - }, - }, - want: accounts.Account{Address: common.HexToAddress("0x000000000000000000000000000000000000dea1"), - URL: accounts.URL{Scheme: "TestKeyScheme", Path: "test/key/path"}, - }, - expectedFatal: false, - }, - { - name: "Test 2: When there is an error in getting new account", - args: args{ - accountErr: errors.New("account error"), - }, - want: accounts.Account{Address: common.HexToAddress("0x00")}, - expectedFatal: true, - }, - { - name: "Test 3: When keystore directory does not exists and mkdir creates it", - args: args{ - account: accounts.Account{Address: common.HexToAddress("0x000000000000000000000000000000000000dea1"), - URL: accounts.URL{Scheme: "TestKeyScheme", Path: "test/key/path"}, - }, - statErr: errors.New("not exists"), - isNotExist: true, - }, - want: accounts.Account{Address: common.HexToAddress("0x000000000000000000000000000000000000dea1"), - URL: accounts.URL{Scheme: "TestKeyScheme", Path: "test/key/path"}, - }, - expectedFatal: false, - }, - { - name: "Test 4: When keystore directory does not exists and there an error creating new one", - args: args{ - account: accounts.Account{Address: common.HexToAddress("0x000000000000000000000000000000000000dea1"), - URL: accounts.URL{Scheme: "TestKeyScheme", Path: "test/key/path"}, - }, - statErr: errors.New("not exists"), - isNotExist: true, - mkdirErr: errors.New("mkdir error"), - }, - want: accounts.Account{Address: common.HexToAddress("0x000000000000000000000000000000000000dea1"), - URL: accounts.URL{Scheme: "TestKeyScheme", Path: "test/key/path"}, - }, - expectedFatal: true, - }, - } - - defer func() { log.ExitFunc = nil }() - var fatal bool - log.ExitFunc = func(int) { fatal = true } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - accountsMock := new(mocks.AccountInterface) - osMock := new(mocks1.OSInterface) - - path.OSUtilsInterface = osMock - AccountUtilsInterface = accountsMock - - accountsMock.On("NewAccount", mock.AnythingOfType("string"), mock.AnythingOfType("string")).Return(tt.args.account, tt.args.accountErr) - osMock.On("Stat", mock.AnythingOfType("string")).Return(fileInfo, tt.args.statErr) - osMock.On("IsNotExist", mock.Anything).Return(tt.args.isNotExist) - osMock.On("Mkdir", mock.Anything, mock.Anything).Return(tt.args.mkdirErr) - - accountUtils := AccountUtils{} - fatal = false - got := accountUtils.CreateAccount(keystorePath, password) - if tt.expectedFatal { - assert.Equal(t, tt.expectedFatal, fatal) - } - if got.Address != tt.want.Address { - t.Errorf("New address created, got = %v, want %v", got, tt.want.Address) - } - }) - } +func privateKeyToHex(privateKey *ecdsa.PrivateKey) string { + return hex.EncodeToString(privateKey.D.Bytes()) } -func TestGetPrivateKeyFromKeystore(t *testing.T) { - var password string - var keystorePath string - var privateKey *ecdsa.PrivateKey - var jsonBytes []byte +func Test_getPrivateKeyFromKeystore(t *testing.T) { + password := "Razor@123" type args struct { - jsonBytes []byte - jsonBytesErr error - key *keystore.Key - keyErr error + keystoreFilePath string + password string } tests := []struct { name string args args - want *ecdsa.PrivateKey + want string wantErr bool }{ { - name: "Test 1: When GetPrivateKey function executes successfully", + name: "Test 1: When keystore file is present and getPrivateKeyFromKeystore function executes successfully", args: args{ - jsonBytes: jsonBytes, - key: &keystore.Key{ - PrivateKey: privateKey, - }, + keystoreFilePath: "test_accounts/UTC--2024-03-20T07-03-56.358521000Z--911654feb423363fb771e04e18d1e7325ae10a91", + password: password, }, - want: privateKey, + want: "b110b1f06b7b64323a6fb768ceab966abe9f65f4e6ab3c39382bd446122f7b01", wantErr: false, }, { - name: "Test 2: When there is an error in reading data from file", + name: "Test 2: When there is no keystore file present at the desired path", args: args{ - jsonBytesErr: errors.New("error in reading data"), - key: &keystore.Key{ - PrivateKey: nil, - }, + keystoreFilePath: "test_accounts/UTC--2024-03-20T07-03-56.358521000Z--211654feb423363fb771e04e18d1e7325ae10a91", + password: password, }, - want: nil, + want: "", wantErr: true, }, { - name: "Test 3: When there is an error in fetching private key", + name: "Test 3: When password is incorrect for the desired keystore file", args: args{ - jsonBytes: jsonBytes, - key: &keystore.Key{ - PrivateKey: nil, - }, - keyErr: errors.New("private key error"), + keystoreFilePath: "test_accounts/UTC--2024-03-20T07-03-56.358521000Z--911654feb423363fb771e04e18d1e7325ae10a91", + password: "Razor@456", }, - want: privateKey, + want: "", wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - accountsMock := new(mocks.AccountInterface) - AccountUtilsInterface = accountsMock - - accountsMock.On("ReadFile", mock.AnythingOfType("string")).Return(tt.args.jsonBytes, tt.args.jsonBytesErr) - accountsMock.On("DecryptKey", mock.Anything, mock.AnythingOfType("string")).Return(tt.args.key, tt.args.keyErr) - - accountUtils := &AccountUtils{} - got, err := accountUtils.GetPrivateKeyFromKeystore(keystorePath, password) - if got != tt.want { - t.Errorf("Private key from GetPrivateKey, got = %v, want %v", got, tt.want) - } + gotPrivateKey, err := getPrivateKeyFromKeystore(tt.args.keystoreFilePath, tt.args.password) if (err != nil) != tt.wantErr { t.Errorf("GetPrivateKeyFromKeystore() error = %v, wantErr %v", err, tt.wantErr) return } + + // If there's no error and a private key is expected, compare the keys + if !tt.wantErr && tt.want != "" { + gotPrivateKeyHex := privateKeyToHex(gotPrivateKey) + if gotPrivateKeyHex != tt.want { + t.Errorf("GetPrivateKey() got private key = %v, want %v", gotPrivateKeyHex, tt.want) + } + } }) } } func TestGetPrivateKey(t *testing.T) { - var password string - var keystorePath string - var privateKey *ecdsa.PrivateKey - - accountsList := []accounts.Account{ - {Address: common.HexToAddress("0x000000000000000000000000000000000000dea1"), - URL: accounts.URL{Scheme: "TestKeyScheme", Path: "test/key/path"}, - }, - {Address: common.HexToAddress("0x000000000000000000000000000000000000dea2"), - URL: accounts.URL{Scheme: "TestKeyScheme", Path: "test/key/path"}, - }, - } + password := "Razor@123" + keystoreDirPath := "test_accounts" type args struct { - address string - accounts []accounts.Account - privateKey *ecdsa.PrivateKey + address string + password string } tests := []struct { name string args args - want *ecdsa.PrivateKey + want string wantErr bool }{ { - name: "Test 1: When input address is present in accountsList", + name: "Test 1: When input address with correct password is present in keystore directory", + args: args{ + address: "0x911654feb423363fb771e04e18d1e7325ae10a91", + password: password, + }, + want: "b110b1f06b7b64323a6fb768ceab966abe9f65f4e6ab3c39382bd446122f7b01", + wantErr: false, + }, + { + name: "Test 2: When input upper case address with correct password is present in keystore directory", args: args{ - address: "0x000000000000000000000000000000000000dea1", - accounts: accountsList, - privateKey: privateKey, + address: "0x2F5F59615689B706B6AD13FD03343DCA28784989", + password: password, }, - want: privateKey, + want: "726223b8b95628edef6cf2774ddde39fb3ea482949c8847fabf74cd994219b50", wantErr: false, }, { - name: "Test 2: When input address is not present in accountsList", + name: "Test 3: When provided address is not present in keystore directory", args: args{ - address: "0x000000000000000000000000000000000000dea3", - accounts: accountsList, - privateKey: privateKey, + address: "0x911654feb423363fb771e04e18d1e7325ae10a91_not_present", }, - want: nil, + want: "", + wantErr: true, + }, + { + name: "Test 4: When input address with incorrect password is present in keystore directory", + args: args{ + address: "0x911654feb423363fb771e04e18d1e7325ae10a91", + password: "incorrect password", + }, + want: "", wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - accountsMock := new(mocks.AccountInterface) - AccountUtilsInterface = accountsMock - - accountsMock.On("Accounts", mock.AnythingOfType("string")).Return(tt.args.accounts) - accountsMock.On("GetPrivateKeyFromKeystore", mock.AnythingOfType("string"), mock.AnythingOfType("string")).Return(tt.args.privateKey, nil) - - accountUtils := &AccountUtils{} - got, err := accountUtils.GetPrivateKey(tt.args.address, password, keystorePath) - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("GetPrivateKey() got = %v, want %v", got, tt.want) - } + am := NewAccountManager(keystoreDirPath) + gotPrivateKey, err := am.GetPrivateKey(tt.args.address, tt.args.password) if (err != nil) != tt.wantErr { t.Errorf("GetPrivateKey() error = %v, wantErr %v", err, tt.wantErr) return } + + // If there's no error and a private key is expected, compare the keys + if !tt.wantErr && tt.want != "" { + gotPrivateKeyHex := privateKeyToHex(gotPrivateKey) + if gotPrivateKeyHex != tt.want { + t.Errorf("GetPrivateKey() got private key = %v, want %v", gotPrivateKeyHex, tt.want) + } + } }) } } func TestSignData(t *testing.T) { - var hash []byte - var account types.Account - var defaultPath string - var privateKey *ecdsa.PrivateKey - var signature []byte + password := "Razor@123" + + hexHash := "a3b8d42c7015c1e9354f8b9c2161d9b2e1ad89e6b6c7a9610e029fd7afec27ae" + hashBytes, err := hex.DecodeString(hexHash) + if err != nil { + log.Fatal("Failed to decode hex string") + } type args struct { - privateKey *ecdsa.PrivateKey - privateKeyErr error - signature []byte - signatureErr error + address string + password string + hash []byte } tests := []struct { name string args args - want []byte + want string wantErr bool }{ { name: "Test 1: When Sign function returns no error", args: args{ - privateKey: privateKey, - signature: signature, - signatureErr: nil, + address: "0x911654feb423363fb771e04e18d1e7325ae10a91", + password: password, + hash: hashBytes, }, - want: signature, + want: "f14cf1b8c9486777e4280b4da6cb7c314d5c19b7e30d32a46f83767a1946e35a39f6941df71375d7ffceaddac81e2454e9a129896803d02f633eb78ab7883ff200", wantErr: false, }, { name: "Test 2: When there is an error in getting private key", args: args{ - privateKeyErr: errors.New("privateKey"), + address: "0x_invalid_address", + password: password, + hash: hashBytes, }, - want: nil, + want: hex.EncodeToString(nil), wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - accountsMock := new(mocks.AccountInterface) - AccountUtilsInterface = accountsMock - - accountsMock.On("GetPrivateKey", mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("string")).Return(tt.args.privateKey, tt.args.privateKeyErr) - accountsMock.On("Sign", mock.Anything, mock.Anything).Return(tt.args.signature, tt.args.signatureErr) - - accountUtils := &AccountUtils{} - - got, err := accountUtils.SignData(hash, account, defaultPath) - - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("Sign() got = %v, want %v", got, tt.want) + am := NewAccountManager("test_accounts") + got, err := am.SignData(tt.args.hash, tt.args.address, tt.args.password) + if !reflect.DeepEqual(hex.EncodeToString(got), tt.want) { + t.Errorf("Sign() got = %v, want %v", hex.EncodeToString(got), tt.want) } if (err != nil) != tt.wantErr { diff --git a/accounts/mocks/account_interface.go b/accounts/mocks/account_interface.go deleted file mode 100644 index e8e006441..000000000 --- a/accounts/mocks/account_interface.go +++ /dev/null @@ -1,224 +0,0 @@ -// Code generated by mockery v2.14.0. DO NOT EDIT. - -package mocks - -import ( - ecdsa "crypto/ecdsa" - - accounts "github.com/ethereum/go-ethereum/accounts" - - keystore "github.com/ethereum/go-ethereum/accounts/keystore" - - mock "github.com/stretchr/testify/mock" - - types "razor/core/types" -) - -// AccountInterface is an autogenerated mock type for the AccountInterface type -type AccountInterface struct { - mock.Mock -} - -// Accounts provides a mock function with given fields: path -func (_m *AccountInterface) Accounts(path string) []accounts.Account { - ret := _m.Called(path) - - var r0 []accounts.Account - if rf, ok := ret.Get(0).(func(string) []accounts.Account); ok { - r0 = rf(path) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]accounts.Account) - } - } - - return r0 -} - -// CreateAccount provides a mock function with given fields: path, password -func (_m *AccountInterface) CreateAccount(path string, password string) accounts.Account { - ret := _m.Called(path, password) - - var r0 accounts.Account - if rf, ok := ret.Get(0).(func(string, string) accounts.Account); ok { - r0 = rf(path, password) - } else { - r0 = ret.Get(0).(accounts.Account) - } - - return r0 -} - -// DecryptKey provides a mock function with given fields: jsonBytes, password -func (_m *AccountInterface) DecryptKey(jsonBytes []byte, password string) (*keystore.Key, error) { - ret := _m.Called(jsonBytes, password) - - var r0 *keystore.Key - if rf, ok := ret.Get(0).(func([]byte, string) *keystore.Key); ok { - r0 = rf(jsonBytes, password) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*keystore.Key) - } - } - - var r1 error - if rf, ok := ret.Get(1).(func([]byte, string) error); ok { - r1 = rf(jsonBytes, password) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// GetPrivateKey provides a mock function with given fields: address, password, keystorePath -func (_m *AccountInterface) GetPrivateKey(address string, password string, keystorePath string) (*ecdsa.PrivateKey, error) { - ret := _m.Called(address, password, keystorePath) - - var r0 *ecdsa.PrivateKey - if rf, ok := ret.Get(0).(func(string, string, string) *ecdsa.PrivateKey); ok { - r0 = rf(address, password, keystorePath) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*ecdsa.PrivateKey) - } - } - - var r1 error - if rf, ok := ret.Get(1).(func(string, string, string) error); ok { - r1 = rf(address, password, keystorePath) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// GetPrivateKeyFromKeystore provides a mock function with given fields: keystorePath, password -func (_m *AccountInterface) GetPrivateKeyFromKeystore(keystorePath string, password string) (*ecdsa.PrivateKey, error) { - ret := _m.Called(keystorePath, password) - - var r0 *ecdsa.PrivateKey - if rf, ok := ret.Get(0).(func(string, string) *ecdsa.PrivateKey); ok { - r0 = rf(keystorePath, password) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*ecdsa.PrivateKey) - } - } - - var r1 error - if rf, ok := ret.Get(1).(func(string, string) error); ok { - r1 = rf(keystorePath, password) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// NewAccount provides a mock function with given fields: path, passphrase -func (_m *AccountInterface) NewAccount(path string, passphrase string) (accounts.Account, error) { - ret := _m.Called(path, passphrase) - - var r0 accounts.Account - if rf, ok := ret.Get(0).(func(string, string) accounts.Account); ok { - r0 = rf(path, passphrase) - } else { - r0 = ret.Get(0).(accounts.Account) - } - - var r1 error - if rf, ok := ret.Get(1).(func(string, string) error); ok { - r1 = rf(path, passphrase) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// ReadFile provides a mock function with given fields: filename -func (_m *AccountInterface) ReadFile(filename string) ([]byte, error) { - ret := _m.Called(filename) - - var r0 []byte - if rf, ok := ret.Get(0).(func(string) []byte); ok { - r0 = rf(filename) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]byte) - } - } - - var r1 error - if rf, ok := ret.Get(1).(func(string) error); ok { - r1 = rf(filename) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// Sign provides a mock function with given fields: digestHash, prv -func (_m *AccountInterface) Sign(digestHash []byte, prv *ecdsa.PrivateKey) ([]byte, error) { - ret := _m.Called(digestHash, prv) - - var r0 []byte - if rf, ok := ret.Get(0).(func([]byte, *ecdsa.PrivateKey) []byte); ok { - r0 = rf(digestHash, prv) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]byte) - } - } - - var r1 error - if rf, ok := ret.Get(1).(func([]byte, *ecdsa.PrivateKey) error); ok { - r1 = rf(digestHash, prv) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// SignData provides a mock function with given fields: hash, account, defaultPath -func (_m *AccountInterface) SignData(hash []byte, account types.Account, defaultPath string) ([]byte, error) { - ret := _m.Called(hash, account, defaultPath) - - var r0 []byte - if rf, ok := ret.Get(0).(func([]byte, types.Account, string) []byte); ok { - r0 = rf(hash, account, defaultPath) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]byte) - } - } - - var r1 error - if rf, ok := ret.Get(1).(func([]byte, types.Account, string) error); ok { - r1 = rf(hash, account, defaultPath) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -type mockConstructorTestingTNewAccountInterface interface { - mock.TestingT - Cleanup(func()) -} - -// NewAccountInterface creates a new instance of AccountInterface. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewAccountInterface(t mockConstructorTestingTNewAccountInterface) *AccountInterface { - mock := &AccountInterface{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} diff --git a/accounts/mocks/account_manager_interface.go b/accounts/mocks/account_manager_interface.go new file mode 100644 index 000000000..0bc33fab1 --- /dev/null +++ b/accounts/mocks/account_manager_interface.go @@ -0,0 +1,120 @@ +// Code generated by mockery v2.30.1. DO NOT EDIT. + +package mocks + +import ( + ecdsa "crypto/ecdsa" + + accounts "github.com/ethereum/go-ethereum/accounts" + + mock "github.com/stretchr/testify/mock" +) + +// AccountManagerInterface is an autogenerated mock type for the AccountManagerInterface type +type AccountManagerInterface struct { + mock.Mock +} + +// CreateAccount provides a mock function with given fields: keystorePath, password +func (_m *AccountManagerInterface) CreateAccount(keystorePath string, password string) accounts.Account { + ret := _m.Called(keystorePath, password) + + var r0 accounts.Account + if rf, ok := ret.Get(0).(func(string, string) accounts.Account); ok { + r0 = rf(keystorePath, password) + } else { + r0 = ret.Get(0).(accounts.Account) + } + + return r0 +} + +// GetPrivateKey provides a mock function with given fields: address, password +func (_m *AccountManagerInterface) GetPrivateKey(address string, password string) (*ecdsa.PrivateKey, error) { + ret := _m.Called(address, password) + + var r0 *ecdsa.PrivateKey + var r1 error + if rf, ok := ret.Get(0).(func(string, string) (*ecdsa.PrivateKey, error)); ok { + return rf(address, password) + } + if rf, ok := ret.Get(0).(func(string, string) *ecdsa.PrivateKey); ok { + r0 = rf(address, password) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*ecdsa.PrivateKey) + } + } + + if rf, ok := ret.Get(1).(func(string, string) error); ok { + r1 = rf(address, password) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewAccount provides a mock function with given fields: passphrase +func (_m *AccountManagerInterface) NewAccount(passphrase string) (accounts.Account, error) { + ret := _m.Called(passphrase) + + var r0 accounts.Account + var r1 error + if rf, ok := ret.Get(0).(func(string) (accounts.Account, error)); ok { + return rf(passphrase) + } + if rf, ok := ret.Get(0).(func(string) accounts.Account); ok { + r0 = rf(passphrase) + } else { + r0 = ret.Get(0).(accounts.Account) + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(passphrase) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// SignData provides a mock function with given fields: hash, address, password +func (_m *AccountManagerInterface) SignData(hash []byte, address string, password string) ([]byte, error) { + ret := _m.Called(hash, address, password) + + var r0 []byte + var r1 error + if rf, ok := ret.Get(0).(func([]byte, string, string) ([]byte, error)); ok { + return rf(hash, address, password) + } + if rf, ok := ret.Get(0).(func([]byte, string, string) []byte); ok { + r0 = rf(hash, address, password) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + if rf, ok := ret.Get(1).(func([]byte, string, string) error); ok { + r1 = rf(hash, address, password) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewAccountManagerInterface creates a new instance of AccountManagerInterface. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewAccountManagerInterface(t interface { + mock.TestingT + Cleanup(func()) +}) *AccountManagerInterface { + mock := &AccountManagerInterface{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/accounts/test_accounts/UTC--2024-03-20T07-03-56.358521000Z--911654feb423363fb771e04e18d1e7325ae10a91 b/accounts/test_accounts/UTC--2024-03-20T07-03-56.358521000Z--911654feb423363fb771e04e18d1e7325ae10a91 new file mode 100644 index 000000000..bc50d72f6 --- /dev/null +++ b/accounts/test_accounts/UTC--2024-03-20T07-03-56.358521000Z--911654feb423363fb771e04e18d1e7325ae10a91 @@ -0,0 +1 @@ +{"address":"911654feb423363fb771e04e18d1e7325ae10a91","crypto":{"cipher":"aes-128-ctr","ciphertext":"032e882238c605aa6ede0c54658b1f26a8800e4b41c67349159236e7ffa76955","cipherparams":{"iv":"12cae1b8475b00f92e2eac08e08a6a39"},"kdf":"scrypt","kdfparams":{"dklen":32,"n":262144,"p":1,"r":8,"salt":"bacdad2b595bda0ed52706842b9a1d20ce64210eeadcf7302e6ba9e2ddc22706"},"mac":"722ea7fc4cee367e92bc403a5371aab1e78e66e3876d00d3a40830130c443101"},"id":"e0b6a612-8877-4400-813a-ef8a627d33d8","version":3} \ No newline at end of file diff --git a/accounts/test_accounts/UTC--2024-03-20T07-04-11.601622000Z--2f5f59615689b706b6ad13fd03343dca28784989 b/accounts/test_accounts/UTC--2024-03-20T07-04-11.601622000Z--2f5f59615689b706b6ad13fd03343dca28784989 new file mode 100644 index 000000000..c773f086e --- /dev/null +++ b/accounts/test_accounts/UTC--2024-03-20T07-04-11.601622000Z--2f5f59615689b706b6ad13fd03343dca28784989 @@ -0,0 +1 @@ +{"address":"2f5f59615689b706b6ad13fd03343dca28784989","crypto":{"cipher":"aes-128-ctr","ciphertext":"e46770162aa3d74c00f1b7dff2a6e255743e61483d3f6a4c266c2697ac05aed3","cipherparams":{"iv":"8941588ce2540c7bef63f0df4a67e6cb"},"kdf":"scrypt","kdfparams":{"dklen":32,"n":262144,"p":1,"r":8,"salt":"e106a7ef7916893083977dbe18cd5edf57491ba21a13adc2db7b79c7c8b85ded"},"mac":"d30b6bf80ea2ddbaca14f7635596e56ac150bda4416e9dc915117b8574cda2ba"},"id":"0accaaa4-2e5c-451b-b813-6551f9c01511","version":3} \ No newline at end of file diff --git a/addresses.json b/addresses.json deleted file mode 100644 index 5226f0726..000000000 --- a/addresses.json +++ /dev/null @@ -1,12 +0,0 @@ -{ - "Governance": "0xD2A74B7A962FFc85827da0124A4278e731D15464", - "BlockManager": "0x096e44B0d8b68376C8Efe40F28C3857951f03069", - "CollectionManager": "0x3b76eB8c0282dAf531D7C507E4f3143A9A9c38b1", - "StakeManager": "0x9f55a2C6C1F1Be8B01562cEae2df2F22931C7a46", - "RewardManager": "0x07875369943951b1Af9c37f4ea137dcED9d9181d", - "VoteManager": "0x11995b74D6d07a6Edc05653a71F3e8B3354caBF0", - "Delegator": "0xe295863DF95AaAeC66E7de11D3aD0C35698d0fE9", - "RAZOR": "0x4500E10fEb89e46E9fb642D0c62b1a761278155D", - "StakedTokenFactory": "0xe20e11687F269fE9e356da92C4B15aBF98BbC9ff", - "RandomNoManager": "0x31463bC4D5E67Bca623fFc6152D253Ea17216fA9" -} \ No newline at end of file diff --git a/addresses/mainnet.json b/addresses/mainnet.json index 84568c78b..81462089a 100644 --- a/addresses/mainnet.json +++ b/addresses/mainnet.json @@ -9,4 +9,4 @@ "RAZOR": "0xcbf70914Fae03B3acB91E953De60CfDAaCA8145f", "StakedTokenFactory": "0xEffA78888Dc1b6033286E5dF9b170bc5223178AB", "RandomNoManager": "0xC6eF45F5Add040800D30FE6dEe01b4EBC4BfC467" - } \ No newline at end of file +} \ No newline at end of file diff --git a/addresses/testnet.json b/addresses/testnet.json index 1e536a19a..a70785723 100644 --- a/addresses/testnet.json +++ b/addresses/testnet.json @@ -1,12 +1,13 @@ { - "Governance": "0xD2A74B7A962FFc85827da0124A4278e731D15464", - "BlockManager": "0x096e44B0d8b68376C8Efe40F28C3857951f03069", - "CollectionManager": "0x3b76eB8c0282dAf531D7C507E4f3143A9A9c38b1", - "StakeManager": "0x9f55a2C6C1F1Be8B01562cEae2df2F22931C7a46", - "RewardManager": "0x07875369943951b1Af9c37f4ea137dcED9d9181d", - "VoteManager": "0x11995b74D6d07a6Edc05653a71F3e8B3354caBF0", - "Delegator": "0xe295863DF95AaAeC66E7de11D3aD0C35698d0fE9", - "RAZOR": "0x4500E10fEb89e46E9fb642D0c62b1a761278155D", - "StakedTokenFactory": "0xe20e11687F269fE9e356da92C4B15aBF98BbC9ff", - "RandomNoManager": "0x31463bC4D5E67Bca623fFc6152D253Ea17216fA9" - } \ No newline at end of file + "Governance": "0x0b9CC11E7f7D3D4f9bBc0cCaB85b73B96C322c78", + "BlockManager": "0x076df1c2d81C40D30DF115Ab68A13251fBD3FFA4", + "CollectionManager": "0x83f4D7ad6dD366c3F04Af45a1437c77636b03388", + "StakeManager": "0xf99a873a6afdF1b24388ac02ea0f1CFE3c70A80b", + "RewardManager": "0x522d2A51639332388dA4788DF59fB4E598278fAd", + "VoteManager": "0x2399D4d92b4D8762971605d0bC9597F780852CC4", + "Delegator": "0x81c72fB60d19Bfc3A0adDfE394501150d290cd66", + "RAZOR": "0x504C6635af1e7E6fdc586c8A89973783b9281A77", + "StakedTokenFactory": "0xf271bd91104946Cb8e43BC0e83423ed073ab136e", + "RandomNoManager": "0x2dc9E4663675c933f78D97cDD44463ee9C43144c", + "DelegatorV2": "0x4e9E5AE58Fb6d7Ca5b79A196b8d0FCF7CA15C100" +} \ No newline at end of file diff --git a/block/block.go b/block/block.go index bd12e70d2..ea896d420 100644 --- a/block/block.go +++ b/block/block.go @@ -32,9 +32,9 @@ func CalculateLatestBlock(client *ethclient.Client) { latestHeader, err := client.HeaderByNumber(context.Background(), nil) if err != nil { logrus.Error("CalculateBlockNumber: Error in fetching block: ", err) - continue + } else { + SetLatestBlock(latestHeader) } - SetLatestBlock(latestHeader) } time.Sleep(time.Second * time.Duration(core.BlockNumberInterval)) } diff --git a/cache/collectionsCache.go b/cache/collectionsCache.go new file mode 100644 index 000000000..c27064b64 --- /dev/null +++ b/cache/collectionsCache.go @@ -0,0 +1,35 @@ +package cache + +import ( + "razor/pkg/bindings" + "sync" +) + +// CollectionsCache struct to hold collection cache and associated mutex +type CollectionsCache struct { + Collections map[uint16]bindings.StructsCollection + Mu sync.RWMutex +} + +// NewCollectionsCache creates a new instance of CollectionsCache +func NewCollectionsCache() *CollectionsCache { + return &CollectionsCache{ + Collections: make(map[uint16]bindings.StructsCollection), + Mu: sync.RWMutex{}, + } +} + +func (c *CollectionsCache) GetCollection(collectionId uint16) (bindings.StructsCollection, bool) { + c.Mu.RLock() + defer c.Mu.RUnlock() + + collection, exists := c.Collections[collectionId] + return collection, exists +} + +func (c *CollectionsCache) UpdateCollection(collectionId uint16, updatedCollection bindings.StructsCollection) { + c.Mu.Lock() + defer c.Mu.Unlock() + + c.Collections[collectionId] = updatedCollection +} diff --git a/cache/jobsCache.go b/cache/jobsCache.go new file mode 100644 index 000000000..a13f1e0cd --- /dev/null +++ b/cache/jobsCache.go @@ -0,0 +1,35 @@ +package cache + +import ( + "razor/pkg/bindings" + "sync" +) + +// JobsCache struct to hold job cache and associated mutex +type JobsCache struct { + Jobs map[uint16]bindings.StructsJob + Mu sync.RWMutex +} + +// NewJobsCache creates a new instance of JobsCache +func NewJobsCache() *JobsCache { + return &JobsCache{ + Jobs: make(map[uint16]bindings.StructsJob), + Mu: sync.RWMutex{}, + } +} + +func (j *JobsCache) GetJob(jobId uint16) (bindings.StructsJob, bool) { + j.Mu.RLock() + defer j.Mu.RUnlock() + + job, exists := j.Jobs[jobId] + return job, exists +} + +func (j *JobsCache) UpdateJob(jobId uint16, updatedJob bindings.StructsJob) { + j.Mu.Lock() + defer j.Mu.Unlock() + + j.Jobs[jobId] = updatedJob +} diff --git a/cmd/addStake.go b/cmd/addStake.go index 9ae55fcd2..b8765442d 100644 --- a/cmd/addStake.go +++ b/cmd/addStake.go @@ -2,6 +2,7 @@ package cmd import ( + "razor/accounts" "razor/core" "razor/core/types" "razor/logger" @@ -49,7 +50,12 @@ func (*UtilsStruct) ExecuteStake(flagSet *pflag.FlagSet) { log.Debug("Getting password...") password := razorUtils.AssignPassword(flagSet) - err = razorUtils.CheckPassword(address, password) + accountManager, err := razorUtils.AccountManagerForKeystore() + utils.CheckError("Error in getting accounts manager for keystore: ", err) + + account := accounts.InitAccountStruct(address, password, accountManager) + + err = razorUtils.CheckPassword(account) utils.CheckError("Error in fetching private key from given password: ", err) balance, err := razorUtils.FetchBalance(client, address) @@ -87,12 +93,11 @@ func (*UtilsStruct) ExecuteStake(flagSet *pflag.FlagSet) { } txnArgs := types.TransactionOptions{ - Client: client, - AccountAddress: address, - Password: password, - Amount: valueInWei, - ChainId: core.ChainId, - Config: config, + Client: client, + Amount: valueInWei, + ChainId: core.ChainId, + Config: config, + Account: account, } log.Debug("ExecuteStake: Calling Approve() for amount: ", txnArgs.Amount) diff --git a/cmd/addStake_test.go b/cmd/addStake_test.go index 5fcba5592..df21fd316 100644 --- a/cmd/addStake_test.go +++ b/cmd/addStake_test.go @@ -1,17 +1,14 @@ package cmd import ( - "crypto/ecdsa" - "crypto/rand" "errors" - "github.com/ethereum/go-ethereum/crypto" "math/big" + "razor/accounts" "razor/core" "razor/core/types" "razor/pkg/bindings" "testing" - "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" Types "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethclient" @@ -20,17 +17,12 @@ import ( ) func TestStakeCoins(t *testing.T) { - - privateKey, _ := ecdsa.GenerateKey(crypto.S256(), rand.Reader) - txnOpts, _ := bind.NewKeyedTransactorWithChainID(privateKey, big.NewInt(31337)) - txnArgs := types.TransactionOptions{ Amount: big.NewInt(10000), } type args struct { txnArgs types.TransactionOptions - txnOpts *bind.TransactOpts epoch uint32 getEpochErr error stakeTxn *Types.Transaction @@ -49,7 +41,6 @@ func TestStakeCoins(t *testing.T) { txnArgs: types.TransactionOptions{ Amount: big.NewInt(1000), }, - txnOpts: txnOpts, epoch: 2, getEpochErr: nil, stakeTxn: &Types.Transaction{}, @@ -65,7 +56,6 @@ func TestStakeCoins(t *testing.T) { txnArgs: types.TransactionOptions{ Amount: big.NewInt(1000), }, - txnOpts: txnOpts, epoch: 2, getEpochErr: errors.New("waitForAppropriateState error"), stakeTxn: &Types.Transaction{}, @@ -81,7 +71,6 @@ func TestStakeCoins(t *testing.T) { txnArgs: types.TransactionOptions{ Amount: big.NewInt(1000), }, - txnOpts: txnOpts, epoch: 2, getEpochErr: nil, stakeTxn: &Types.Transaction{}, @@ -96,7 +85,7 @@ func TestStakeCoins(t *testing.T) { t.Run(tt.name, func(t *testing.T) { SetUpMockInterfaces() utilsMock.On("GetEpoch", mock.AnythingOfType("*ethclient.Client")).Return(tt.args.epoch, tt.args.getEpochErr) - utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(txnOpts) + utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(TxnOpts) transactionMock.On("Hash", mock.Anything).Return(tt.args.hash) stakeManagerMock.On("Stake", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.stakeTxn, tt.args.stakeErr) @@ -330,7 +319,8 @@ func TestExecuteStake(t *testing.T) { fileUtilsMock.On("AssignLogFile", mock.AnythingOfType("*pflag.FlagSet"), mock.Anything) cmdUtilsMock.On("GetConfigData").Return(tt.args.config, tt.args.configErr) utilsMock.On("AssignPassword", mock.AnythingOfType("*pflag.FlagSet")).Return(tt.args.password) - utilsMock.On("CheckPassword", mock.Anything, mock.Anything).Return(nil) + utilsMock.On("CheckPassword", mock.Anything).Return(nil) + utilsMock.On("AccountManagerForKeystore").Return(&accounts.AccountManager{}, nil) flagSetMock.On("GetStringAddress", mock.AnythingOfType("*pflag.FlagSet")).Return(tt.args.address, tt.args.addressErr) utilsMock.On("ConnectToClient", mock.AnythingOfType("string")).Return(client) utilsMock.On("WaitForBlockCompletion", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("string")).Return(nil) diff --git a/cmd/approve.go b/cmd/approve.go index 1a8cf1b8a..76f5e991f 100644 --- a/cmd/approve.go +++ b/cmd/approve.go @@ -11,7 +11,7 @@ import ( //This function approves the transaction if the user has sufficient balance otherwise it fails the transaction func (*UtilsStruct) Approve(txnArgs types.TransactionOptions) (common.Hash, error) { opts := razorUtils.GetOptions() - allowance, err := tokenManagerUtils.Allowance(txnArgs.Client, &opts, common.HexToAddress(txnArgs.AccountAddress), common.HexToAddress(core.StakeManagerAddress)) + allowance, err := tokenManagerUtils.Allowance(txnArgs.Client, &opts, common.HexToAddress(txnArgs.Account.Address), common.HexToAddress(core.StakeManagerAddress)) if err != nil { return core.NilHash, err } diff --git a/cmd/approve_test.go b/cmd/approve_test.go index 6dbce4ad5..218c9bae4 100644 --- a/cmd/approve_test.go +++ b/cmd/approve_test.go @@ -2,13 +2,10 @@ package cmd import ( "context" - "crypto/ecdsa" - "crypto/rand" "errors" "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" Types "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/crypto" "github.com/stretchr/testify/mock" "math/big" "razor/core" @@ -17,14 +14,9 @@ import ( ) func TestApprove(t *testing.T) { - - privateKey, _ := ecdsa.GenerateKey(crypto.S256(), rand.Reader) - txnOpts, _ := bind.NewKeyedTransactorWithChainID(privateKey, big.NewInt(1)) - type args struct { txnArgs types.TransactionOptions callOpts bind.CallOpts - transactOpts *bind.TransactOpts allowanceAmount *big.Int allowanceError error approveTxn *Types.Transaction @@ -50,7 +42,6 @@ func TestApprove(t *testing.T) { BlockNumber: big.NewInt(1), Context: context.Background(), }, - transactOpts: txnOpts, allowanceAmount: big.NewInt(0), allowanceError: nil, approveTxn: &Types.Transaction{}, @@ -72,7 +63,6 @@ func TestApprove(t *testing.T) { BlockNumber: big.NewInt(1), Context: context.Background(), }, - transactOpts: txnOpts, allowanceAmount: big.NewInt(10000), allowanceError: nil, approveTxn: &Types.Transaction{}, @@ -94,7 +84,6 @@ func TestApprove(t *testing.T) { BlockNumber: big.NewInt(1), Context: context.Background(), }, - transactOpts: txnOpts, allowanceAmount: big.NewInt(0), allowanceError: errors.New("allowance error"), approveTxn: &Types.Transaction{}, @@ -117,7 +106,6 @@ func TestApprove(t *testing.T) { BlockNumber: big.NewInt(1), Context: context.Background(), }, - transactOpts: txnOpts, allowanceAmount: big.NewInt(0), allowanceError: nil, approveTxn: &Types.Transaction{}, @@ -133,7 +121,7 @@ func TestApprove(t *testing.T) { SetUpMockInterfaces() utilsMock.On("GetOptions").Return(tt.args.callOpts) - utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(txnOpts) + utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(TxnOpts) transactionMock.On("Hash", mock.Anything).Return(tt.args.hash) tokenManagerMock.On("Allowance", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.allowanceAmount, tt.args.allowanceError) tokenManagerMock.On("Approve", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.approveTxn, tt.args.approveError) diff --git a/cmd/claimBounty.go b/cmd/claimBounty.go index e4513cd68..0e91572bd 100644 --- a/cmd/claimBounty.go +++ b/cmd/claimBounty.go @@ -5,6 +5,7 @@ import ( "errors" "math/big" "os" + "razor/accounts" "razor/core" "razor/core/types" "razor/logger" @@ -53,7 +54,12 @@ func (*UtilsStruct) ExecuteClaimBounty(flagSet *pflag.FlagSet) { log.Debug("Getting password...") password := razorUtils.AssignPassword(flagSet) - err = razorUtils.CheckPassword(address, password) + accountManager, err := razorUtils.AccountManagerForKeystore() + utils.CheckError("Error in getting accounts manager for keystore: ", err) + + account := accounts.InitAccountStruct(address, password, accountManager) + + err = razorUtils.CheckPassword(account) utils.CheckError("Error in fetching private key from given password: ", err) if razorUtils.IsFlagPassed("bountyId") { @@ -62,12 +68,10 @@ func (*UtilsStruct) ExecuteClaimBounty(flagSet *pflag.FlagSet) { log.Debug("ExecuteClaimBounty: BountyId: ", bountyId) redeemBountyInput := types.RedeemBountyInput{ - Address: address, - Password: password, BountyId: bountyId, + Account: account, } - log.Debugf("ExecuteClaimBounty: Calling ClaimBounty() with arguments redeem bounty input: %+v", redeemBountyInput) txn, err := cmdUtils.ClaimBounty(config, client, redeemBountyInput) utils.CheckError("ClaimBounty error: ", err) @@ -77,10 +81,7 @@ func (*UtilsStruct) ExecuteClaimBounty(flagSet *pflag.FlagSet) { } } else { log.Debug("ExecuteClaimBounty: Calling HandleClaimBounty()") - err := cmdUtils.HandleClaimBounty(client, config, types.Account{ - Address: address, - Password: password, - }) + err := cmdUtils.HandleClaimBounty(client, config, account) utils.CheckError("HandleClaimBounty error: ", err) } @@ -114,8 +115,7 @@ func (*UtilsStruct) HandleClaimBounty(client *ethclient.Client, config types.Con log.Info("Claiming bounty for bountyId ", disputeData.BountyIdQueue[length-1]) redeemBountyInput := types.RedeemBountyInput{ BountyId: disputeData.BountyIdQueue[length-1], - Address: account.Address, - Password: account.Password, + Account: account, } log.Debugf("HandleClaimBounty: Calling ClaimBounty() with arguments redeemBountyInput: %+v", redeemBountyInput) claimBountyTxn, err := cmdUtils.ClaimBounty(config, client, redeemBountyInput) @@ -147,8 +147,7 @@ func (*UtilsStruct) HandleClaimBounty(client *ethclient.Client, config types.Con func (*UtilsStruct) ClaimBounty(config types.Configurations, client *ethclient.Client, redeemBountyInput types.RedeemBountyInput) (common.Hash, error) { txnArgs := types.TransactionOptions{ Client: client, - AccountAddress: redeemBountyInput.Address, - Password: redeemBountyInput.Password, + Account: redeemBountyInput.Account, ChainId: core.ChainId, Config: config, ContractAddress: core.StakeManagerAddress, diff --git a/cmd/claimBounty_test.go b/cmd/claimBounty_test.go index a5fdc1d41..848540833 100644 --- a/cmd/claimBounty_test.go +++ b/cmd/claimBounty_test.go @@ -1,12 +1,10 @@ package cmd import ( - "crypto/ecdsa" - "crypto/rand" "errors" - "github.com/ethereum/go-ethereum/crypto" "io/fs" "math/big" + "razor/accounts" "razor/cmd/mocks" "razor/core" "razor/core/types" @@ -126,7 +124,8 @@ func TestExecuteClaimBounty(t *testing.T) { fileUtilsMock.On("AssignLogFile", mock.AnythingOfType("*pflag.FlagSet"), mock.Anything) cmdUtilsMock.On("GetConfigData").Return(tt.args.config, tt.args.configErr) utilsMock.On("AssignPassword", mock.AnythingOfType("*pflag.FlagSet")).Return(tt.args.password) - utilsMock.On("CheckPassword", mock.Anything, mock.Anything).Return(nil) + utilsMock.On("CheckPassword", mock.Anything).Return(nil) + utilsMock.On("AccountManagerForKeystore").Return(&accounts.AccountManager{}, nil) flagSetMock.On("GetStringAddress", mock.AnythingOfType("*pflag.FlagSet")).Return(tt.args.address, tt.args.addressErr) flagSetMock.On("GetUint32BountyId", mock.AnythingOfType("*pflag.FlagSet")).Return(tt.args.bountyId, tt.args.bountyIdErr) utilsMock.On("ConnectToClient", mock.AnythingOfType("string")).Return(client) @@ -148,9 +147,6 @@ func TestExecuteClaimBounty(t *testing.T) { } func TestClaimBounty(t *testing.T) { - privateKey, _ := ecdsa.GenerateKey(crypto.S256(), rand.Reader) - txnOpts, _ := bind.NewKeyedTransactorWithChainID(privateKey, big.NewInt(1)) - var config types.Configurations var client *ethclient.Client var bountyInput types.RedeemBountyInput @@ -273,7 +269,7 @@ func TestClaimBounty(t *testing.T) { stakeManagerMock.On("GetBountyLock", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("*bind.CallOpts"), mock.AnythingOfType("uint32")).Return(tt.args.bountyLock, tt.args.bountyLockErr) timeMock.On("Sleep", mock.AnythingOfType("time.Duration")).Return() utilsMock.On("CalculateBlockTime", mock.AnythingOfType("*ethclient.Client")).Return(blockTime) - utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(txnOpts) + utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(TxnOpts) stakeManagerMock.On("RedeemBounty", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("*bind.TransactOpts"), mock.AnythingOfType("uint32")).Return(tt.args.redeemBountyTxn, tt.args.redeemBountyErr) utilsMock.On("SecondsToReadableTime", mock.AnythingOfType("int")).Return(tt.args.time) transactionUtilsMock.On("Hash", mock.Anything).Return(tt.args.hash) diff --git a/cmd/claimCommission.go b/cmd/claimCommission.go index e79b71db7..6de6853df 100644 --- a/cmd/claimCommission.go +++ b/cmd/claimCommission.go @@ -3,6 +3,7 @@ package cmd import ( "math/big" + "razor/accounts" "razor/core" "razor/core/types" "razor/logger" @@ -45,7 +46,12 @@ func (*UtilsStruct) ClaimCommission(flagSet *pflag.FlagSet) { log.Debug("Getting password...") password := razorUtils.AssignPassword(flagSet) - err = razorUtils.CheckPassword(address, password) + accountManager, err := razorUtils.AccountManagerForKeystore() + utils.CheckError("Error in getting accounts manager for keystore: ", err) + + account := accounts.InitAccountStruct(address, password, accountManager) + + err = razorUtils.CheckPassword(account) utils.CheckError("Error in fetching private key from given password: ", err) stakerId, err := razorUtils.GetStakerId(client, address) @@ -60,14 +66,13 @@ func (*UtilsStruct) ClaimCommission(flagSet *pflag.FlagSet) { if stakerInfo.StakerReward.Cmp(big.NewInt(0)) > 0 { txnOpts := razorUtils.GetTxnOpts(types.TransactionOptions{ Client: client, - AccountAddress: address, - Password: password, ChainId: core.ChainId, Config: config, ContractAddress: core.StakeManagerAddress, MethodName: "claimStakerReward", Parameters: []interface{}{}, ABI: bindings.StakeManagerMetaData.ABI, + Account: account, }) log.Info("Claiming commission...") diff --git a/cmd/claimCommission_test.go b/cmd/claimCommission_test.go index c79948342..50a252130 100644 --- a/cmd/claimCommission_test.go +++ b/cmd/claimCommission_test.go @@ -1,11 +1,9 @@ package cmd import ( - "crypto/ecdsa" - "crypto/rand" "errors" - "github.com/ethereum/go-ethereum/crypto" "math/big" + "razor/accounts" "razor/core/types" "testing" @@ -22,9 +20,6 @@ func TestClaimCommission(t *testing.T) { var flagSet *pflag.FlagSet var callOpts bind.CallOpts - privateKey, _ := ecdsa.GenerateKey(crypto.S256(), rand.Reader) - txnOpts, _ := bind.NewKeyedTransactorWithChainID(privateKey, big.NewInt(31337)) - type args struct { config types.Configurations configErr error @@ -207,9 +202,10 @@ func TestClaimCommission(t *testing.T) { utilsMock.On("GetStakerId", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("string")).Return(tt.args.stakerId, tt.args.stakerIdErr) utilsMock.On("GetOptions").Return(callOpts) utilsMock.On("AssignPassword", mock.AnythingOfType("*pflag.FlagSet")).Return(tt.args.password) - utilsMock.On("CheckPassword", mock.Anything, mock.Anything).Return(nil) + utilsMock.On("CheckPassword", mock.Anything).Return(nil) + utilsMock.On("AccountManagerForKeystore").Return(&accounts.AccountManager{}, nil) utilsMock.On("ConnectToClient", mock.AnythingOfType("string")).Return(client) - utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(txnOpts) + utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(TxnOpts) utilsMock.On("WaitForBlockCompletion", mock.AnythingOfType("*ethclient.Client"), mock.Anything).Return(nil) stakeManagerMock.On("StakerInfo", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("*bind.CallOpts"), mock.AnythingOfType("uint32")).Return(tt.args.stakerInfo, tt.args.stakerInfoErr) diff --git a/cmd/cmd-utils.go b/cmd/cmd-utils.go index 30a84e1cd..2583c3586 100644 --- a/cmd/cmd-utils.go +++ b/cmd/cmd-utils.go @@ -23,7 +23,12 @@ func (*UtilsStruct) GetEpochAndState(client *ethclient.Client) (uint32, int64, e if err != nil { return 0, 0, err } - state, err := razorUtils.GetBufferedState(client, bufferPercent) + latestHeader, err := clientUtils.GetLatestBlockWithRetry(client) + if err != nil { + log.Error("Error in fetching block: ", err) + return 0, 0, err + } + state, err := razorUtils.GetBufferedState(client, latestHeader, bufferPercent) if err != nil { return 0, 0, err } diff --git a/cmd/cmd-utils_test.go b/cmd/cmd-utils_test.go index 63f543611..0da68b1e5 100644 --- a/cmd/cmd-utils_test.go +++ b/cmd/cmd-utils_test.go @@ -2,6 +2,7 @@ package cmd import ( "errors" + Types "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethclient" "github.com/spf13/pflag" "github.com/stretchr/testify/mock" @@ -15,6 +16,8 @@ func TestGetEpochAndState(t *testing.T) { type args struct { epoch uint32 epochErr error + latestHeader *Types.Header + latestHeaderErr error bufferPercent int32 bufferPercentErr error state int64 @@ -32,6 +35,7 @@ func TestGetEpochAndState(t *testing.T) { name: "Test 1: When GetEpochAndState function executes successfully", args: args{ epoch: 4, + latestHeader: &Types.Header{}, bufferPercent: 20, state: 0, stateName: "commit", @@ -44,6 +48,7 @@ func TestGetEpochAndState(t *testing.T) { name: "Test 2: When there is an error in getting epoch", args: args{ epochErr: errors.New("epoch error"), + latestHeader: &Types.Header{}, bufferPercent: 20, state: 0, stateName: "commit", @@ -56,6 +61,7 @@ func TestGetEpochAndState(t *testing.T) { name: "Test 3: When there is an error in getting bufferPercent", args: args{ epoch: 4, + latestHeader: &Types.Header{}, bufferPercentErr: errors.New("bufferPercent error"), state: 0, stateName: "commit", @@ -68,6 +74,7 @@ func TestGetEpochAndState(t *testing.T) { name: "Test 4: When there is an error in getting state", args: args{ epoch: 4, + latestHeader: &Types.Header{}, bufferPercent: 20, stateErr: errors.New("state error"), }, @@ -75,6 +82,19 @@ func TestGetEpochAndState(t *testing.T) { wantState: 0, wantErr: errors.New("state error"), }, + { + name: "Test 5: When there is an error in getting latest header", + args: args{ + epoch: 4, + latestHeaderErr: errors.New("header error"), + bufferPercent: 20, + state: 0, + stateName: "commit", + }, + wantEpoch: 0, + wantState: 0, + wantErr: errors.New("header error"), + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -82,7 +102,8 @@ func TestGetEpochAndState(t *testing.T) { utilsMock.On("GetEpoch", mock.AnythingOfType("*ethclient.Client")).Return(tt.args.epoch, tt.args.epochErr) cmdUtilsMock.On("GetBufferPercent").Return(tt.args.bufferPercent, tt.args.bufferPercentErr) - utilsMock.On("GetBufferedState", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("int32")).Return(tt.args.state, tt.args.stateErr) + clientUtilsMock.On("GetLatestBlockWithRetry", mock.Anything).Return(tt.args.latestHeader, tt.args.latestHeaderErr) + utilsMock.On("GetBufferedState", mock.Anything, mock.Anything, mock.Anything).Return(tt.args.state, tt.args.stateErr) utils := &UtilsStruct{} gotEpoch, gotState, err := utils.GetEpochAndState(client) diff --git a/cmd/commit.go b/cmd/commit.go index 83a496142..df340bf5d 100644 --- a/cmd/commit.go +++ b/cmd/commit.go @@ -10,8 +10,11 @@ import ( "razor/core/types" "razor/pkg/bindings" "razor/utils" + "sync" "time" + Types "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/ethclient" solsha3 "github.com/miguelmota/go-solidity-sha3" @@ -55,7 +58,7 @@ func (*UtilsStruct) GetSalt(client *ethclient.Client, epoch uint32) ([32]byte, e HandleCommitState fetches the collections assigned to the staker and creates the leaves required for the merkle tree generation. Values for only the collections assigned to the staker is fetched for others, 0 is added to the leaves of tree. */ -func (*UtilsStruct) HandleCommitState(client *ethclient.Client, epoch uint32, seed []byte, rogueData types.Rogue) (types.CommitData, error) { +func (*UtilsStruct) HandleCommitState(client *ethclient.Client, epoch uint32, seed []byte, commitParams *types.CommitParams, rogueData types.Rogue) (types.CommitData, error) { numActiveCollections, err := razorUtils.GetNumActiveCollections(client) if err != nil { return types.CommitData{}, err @@ -67,40 +70,76 @@ func (*UtilsStruct) HandleCommitState(client *ethclient.Client, epoch uint32, se return types.CommitData{}, err } - var leavesOfTree []*big.Int + leavesOfTree := make([]*big.Int, numActiveCollections) + results := make(chan types.CollectionResult, numActiveCollections) + errChan := make(chan error, numActiveCollections) + + defer close(results) + defer close(errChan) + + var wg sync.WaitGroup log.Debug("Creating a local cache which will store API result and expire at the end of commit state") - localCache := cache.NewLocalCache(time.Second * time.Duration(core.StateLength)) + commitParams.LocalCache = cache.NewLocalCache(time.Second * time.Duration(core.StateLength)) log.Debug("Iterating over all the collections...") for i := 0; i < int(numActiveCollections); i++ { - log.Debug("HandleCommitState: Iterating index: ", i) - log.Debug("HandleCommitState: Is the collection assigned: ", assignedCollections[i]) - if assignedCollections[i] { - collectionId, err := razorUtils.GetCollectionIdFromIndex(client, uint16(i)) - if err != nil { - return types.CommitData{}, err + wg.Add(1) + go func(i int) { + defer wg.Done() + var leaf *big.Int + + log.Debugf("HandleCommitState: Is the collection at iterating index %v assigned: %v ", i, assignedCollections[i]) + if assignedCollections[i] { + collectionId, err := razorUtils.GetCollectionIdFromIndex(client, uint16(i)) + if err != nil { + log.Error("Error in getting collection ID: ", err) + errChan <- err + return + } + collectionData, err := razorUtils.GetAggregatedDataOfCollection(client, collectionId, epoch, commitParams) + if err != nil { + log.Error("Error in getting aggregated data of collection: ", err) + errChan <- err + return + } + if rogueData.IsRogue && utils.Contains(rogueData.RogueMode, "commit") { + log.Warn("YOU ARE COMMITTING VALUES IN ROGUE MODE, THIS CAN INCUR PENALTIES!") + collectionData = razorUtils.GetRogueRandomValue(100000) + log.Debug("HandleCommitState: Collection data in rogue mode: ", collectionData) + } + log.Debugf("HandleCommitState: Data of collection %d: %s", collectionId, collectionData) + leaf = collectionData + } else { + leaf = big.NewInt(0) } - collectionData, err := razorUtils.GetAggregatedDataOfCollection(client, collectionId, epoch, localCache) + log.Debugf("Sending index: %v, leaf data: %v to results channel", i, leaf) + results <- types.CollectionResult{Index: i, Leaf: leaf} + }(i) + } + + wg.Wait() + + for i := 0; i < int(numActiveCollections); i++ { + select { + case result := <-results: + log.Infof("Received from results: Index: %d, Leaf: %v", result.Index, result.Leaf) + leavesOfTree[result.Index] = result.Leaf + case err := <-errChan: if err != nil { + // Returning the first error from the error channel + log.Error("Error in getting collection data: ", err) + commitParams.LocalCache.StopCleanup() return types.CommitData{}, err } - if rogueData.IsRogue && utils.Contains(rogueData.RogueMode, "commit") { - log.Warn("YOU ARE COMMITTING VALUES IN ROGUE MODE, THIS CAN INCUR PENALTIES!") - collectionData = razorUtils.GetRogueRandomValue(100000) - log.Debug("HandleCommitState: Collection data in rogue mode: ", collectionData) - } - log.Debugf("HandleCommitState: Data of collection %d: %s", collectionId, collectionData) - leavesOfTree = append(leavesOfTree, collectionData) - } else { - leavesOfTree = append(leavesOfTree, big.NewInt(0)) } } + log.Debug("HandleCommitState: Assigned Collections: ", assignedCollections) log.Debug("HandleCommitState: SeqAllottedCollections: ", seqAllottedCollections) log.Debug("HandleCommitState: Leaves: ", leavesOfTree) - localCache.StopCleanup() + commitParams.LocalCache.StopCleanup() return types.CommitData{ AssignedCollections: assignedCollections, @@ -112,8 +151,8 @@ func (*UtilsStruct) HandleCommitState(client *ethclient.Client, epoch uint32, se /* Commit finally commits the data to the smart contract. It calculates the commitment to send using the merkle tree root and the seed. */ -func (*UtilsStruct) Commit(client *ethclient.Client, config types.Configurations, account types.Account, epoch uint32, seed []byte, values []*big.Int) (common.Hash, error) { - if state, err := razorUtils.GetBufferedState(client, config.BufferPercent); err != nil || state != 0 { +func (*UtilsStruct) Commit(client *ethclient.Client, config types.Configurations, account types.Account, epoch uint32, latestHeader *Types.Header, seed []byte, values []*big.Int) (common.Hash, error) { + if state, err := razorUtils.GetBufferedState(client, latestHeader, config.BufferPercent); err != nil || state != 0 { log.Error("Not commit state") return core.NilHash, err } @@ -126,14 +165,13 @@ func (*UtilsStruct) Commit(client *ethclient.Client, config types.Configurations txnOpts := razorUtils.GetTxnOpts(types.TransactionOptions{ Client: client, - Password: account.Password, - AccountAddress: account.Address, ChainId: core.ChainId, Config: config, ContractAddress: core.VoteManagerAddress, ABI: bindings.VoteManagerMetaData.ABI, MethodName: "commit", Parameters: []interface{}{epoch, commitmentToSend}, + Account: account, }) log.Info("Commitment sent...") diff --git a/cmd/commit_test.go b/cmd/commit_test.go index e42dc5c53..bba1409a7 100644 --- a/cmd/commit_test.go +++ b/cmd/commit_test.go @@ -1,42 +1,38 @@ package cmd import ( - "crypto/ecdsa" - "crypto/rand" "encoding/hex" "errors" "fmt" - "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" Types "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethclient" "github.com/stretchr/testify/mock" "math/big" + "razor/cache" "razor/core" "razor/core/types" "razor/pkg/bindings" "razor/utils" "reflect" "testing" + "time" ) func TestCommit(t *testing.T) { var ( - client *ethclient.Client - account types.Account - config types.Configurations - seed []byte - epoch uint32 + client *ethclient.Client + account types.Account + config types.Configurations + latestHeader *Types.Header + seed []byte + epoch uint32 ) - privateKey, _ := ecdsa.GenerateKey(crypto.S256(), rand.Reader) - txnOpts, _ := bind.NewKeyedTransactorWithChainID(privateKey, big.NewInt(1)) type args struct { values []*big.Int state int64 stateErr error - txnOpts *bind.TransactOpts commitTxn *Types.Transaction commitErr error hash common.Hash @@ -53,7 +49,6 @@ func TestCommit(t *testing.T) { values: []*big.Int{big.NewInt(1)}, state: 0, stateErr: nil, - txnOpts: txnOpts, commitTxn: &Types.Transaction{}, commitErr: nil, hash: common.BigToHash(big.NewInt(1)), @@ -66,7 +61,6 @@ func TestCommit(t *testing.T) { args: args{ values: []*big.Int{big.NewInt(1)}, stateErr: errors.New("state error"), - txnOpts: txnOpts, commitTxn: &Types.Transaction{}, commitErr: nil, hash: common.BigToHash(big.NewInt(1)), @@ -80,7 +74,6 @@ func TestCommit(t *testing.T) { values: []*big.Int{big.NewInt(1)}, state: 0, stateErr: nil, - txnOpts: txnOpts, commitTxn: &Types.Transaction{}, commitErr: errors.New("commit error"), hash: common.BigToHash(big.NewInt(1)), @@ -104,13 +97,13 @@ func TestCommit(t *testing.T) { utils.MerkleInterface = &utils.MerkleTreeStruct{} merkleUtils = utils.MerkleInterface - utilsMock.On("GetBufferedState", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("int32")).Return(tt.args.state, tt.args.stateErr) - utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(tt.args.txnOpts) + utilsMock.On("GetBufferedState", mock.AnythingOfType("*ethclient.Client"), mock.Anything, mock.Anything).Return(tt.args.state, tt.args.stateErr) + utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(TxnOpts) voteManagerMock.On("Commit", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("*bind.TransactOpts"), mock.AnythingOfType("uint32"), mock.Anything).Return(tt.args.commitTxn, tt.args.commitErr) transactionMock.On("Hash", mock.AnythingOfType("*types.Transaction")).Return(tt.args.hash) utils := &UtilsStruct{} - got, err := utils.Commit(client, config, account, epoch, seed, tt.args.values) + got, err := utils.Commit(client, config, account, epoch, latestHeader, seed, tt.args.values) if got != tt.want { t.Errorf("Txn hash for Commit function, got = %v, want = %v", got, tt.want) } @@ -233,16 +226,21 @@ func TestHandleCommitState(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + localCache := cache.NewLocalCache(time.Second * 10) + commitParams := &types.CommitParams{ + LocalCache: localCache, + } + SetUpMockInterfaces() utilsMock.On("GetNumActiveCollections", mock.AnythingOfType("*ethclient.Client")).Return(tt.args.numActiveCollections, tt.args.numActiveCollectionsErr) utilsMock.On("GetAssignedCollections", mock.AnythingOfType("*ethclient.Client"), mock.Anything, mock.Anything).Return(tt.args.assignedCollections, tt.args.seqAllottedCollections, tt.args.assignedCollectionsErr) utilsMock.On("GetCollectionIdFromIndex", mock.AnythingOfType("*ethclient.Client"), mock.Anything).Return(tt.args.collectionId, tt.args.collectionIdErr) - utilsMock.On("GetAggregatedDataOfCollection", mock.AnythingOfType("*ethclient.Client"), mock.Anything, mock.Anything, mock.Anything).Return(tt.args.collectionData, tt.args.collectionDataErr) + utilsMock.On("GetAggregatedDataOfCollection", mock.AnythingOfType("*ethclient.Client"), mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.collectionData, tt.args.collectionDataErr) utilsMock.On("GetRogueRandomValue", mock.Anything).Return(rogueValue) utils := &UtilsStruct{} - got, err := utils.HandleCommitState(client, epoch, seed, tt.args.rogueData) + got, err := utils.HandleCommitState(client, epoch, seed, commitParams, tt.args.rogueData) if !reflect.DeepEqual(got, tt.want) { t.Errorf("Data from HandleCommitState function, got = %v, want = %v", got, tt.want) } @@ -398,16 +396,21 @@ func BenchmarkHandleCommitState(b *testing.B) { for _, v := range table { b.Run(fmt.Sprintf("Number_Of_Active_Collections%d", v.numActiveCollections), func(b *testing.B) { for i := 0; i < b.N; i++ { + localCache := cache.NewLocalCache(time.Second * 10) + commitParams := &types.CommitParams{ + LocalCache: localCache, + } + SetUpMockInterfaces() utilsMock.On("GetNumActiveCollections", mock.AnythingOfType("*ethclient.Client")).Return(v.numActiveCollections, nil) utilsMock.On("GetAssignedCollections", mock.AnythingOfType("*ethclient.Client"), mock.Anything, mock.Anything).Return(v.assignedCollections, nil, nil) utilsMock.On("GetCollectionIdFromIndex", mock.AnythingOfType("*ethclient.Client"), mock.Anything).Return(uint16(1), nil) - utilsMock.On("GetAggregatedDataOfCollection", mock.AnythingOfType("*ethclient.Client"), mock.Anything, mock.Anything, mock.Anything).Return(big.NewInt(1000), nil) + utilsMock.On("GetAggregatedDataOfCollection", mock.AnythingOfType("*ethclient.Client"), mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(big.NewInt(1000), nil) utilsMock.On("GetRogueRandomValue", mock.Anything).Return(rogueValue) ut := &UtilsStruct{} - _, err := ut.HandleCommitState(client, epoch, seed, types.Rogue{IsRogue: false}) + _, err := ut.HandleCommitState(client, epoch, seed, commitParams, types.Rogue{IsRogue: false}) if err != nil { log.Fatal(err) } diff --git a/cmd/config-utils.go b/cmd/config-utils.go index 7ad7073e9..a5e5839a1 100644 --- a/cmd/config-utils.go +++ b/cmd/config-utils.go @@ -1,4 +1,4 @@ -//Package cmd provides all functions related to command line +// Package cmd provides all functions related to command line package cmd import ( @@ -99,7 +99,6 @@ func (*UtilsStruct) GetConfigData() (types.Configurations, error) { config.RPCTimeout = rpcTimeout utils.RPCTimeout = rpcTimeout config.HTTPTimeout = httpTimeout - utils.HTTPTimeout = httpTimeout config.LogFileMaxSize = logFileMaxSize config.LogFileMaxBackups = logFileMaxBackups config.LogFileMaxAge = logFileMaxAge @@ -109,245 +108,295 @@ func (*UtilsStruct) GetConfigData() (types.Configurations, error) { return config, nil } +func getConfigValueForKey(key string, dataType string) interface{} { + switch dataType { + case "string": + return viper.GetString(key) + case "float32": // Note: viper doesn't have GetFloat32 + return float32(viper.GetFloat64(key)) + case "float64": + return viper.GetFloat64(key) + case "int": + return viper.GetInt(key) + case "int32": + return viper.GetInt32(key) + case "int64": + return viper.GetInt64(key) + case "uint64": + return viper.GetUint64(key) + default: + log.Fatalf("Unsupported data type: %s", dataType) + return nil + } +} + +func getConfigValue(flagName string, dataType string, defaultReturnValue interface{}, viperKey string) (interface{}, error) { + // Check if the config parameter was passed as a root flag in the command. + if flagSetUtils.Changed(rootCmd.Flags(), flagName) { + // Getting the root flag input + rootFlagValue, err := flagSetUtils.FetchRootFlagInput(flagName, dataType) + if err != nil { + log.Errorf("Error in getting value from root flag") + return defaultReturnValue, err + } + log.Debugf("%v flag passed as root flag, Taking value of config %v = %v ", flagName, flagName, rootFlagValue) + return rootFlagValue, nil + } + + // Checking if value of config parameter is present in config file + if viper.IsSet(viperKey) { + valueForKey := getConfigValueForKey(viperKey, dataType) + log.Debugf("Taking value of config %v = %v from config file", viperKey, valueForKey) + return valueForKey, nil + } + log.Debugf("%v config is not set, taking its default value %v", viperKey, defaultReturnValue) + return defaultReturnValue, nil +} + //This function returns the provider func (*UtilsStruct) GetProvider() (string, error) { - provider, err := flagSetUtils.GetRootStringProvider() + provider, err := getConfigValue("provider", "string", "", "provider") if err != nil { return "", err } - if provider == "" { - if viper.IsSet("provider") { - provider = viper.GetString("provider") - } else { - log.Error("Provider is not set in config file") - return "", errors.New("provider is not set") - } + providerString := provider.(string) + if providerString == "" { + return "", errors.New("provider is not set") } - if !strings.HasPrefix(provider, "https") { + if !strings.HasPrefix(providerString, "https") { log.Warn("You are not using a secure RPC URL. Switch to an https URL instead to be safe.") } - return provider, nil + return providerString, nil } //This function returns the alternate provider func (*UtilsStruct) GetAlternateProvider() (string, error) { - alternateProvider, err := flagSetUtils.GetRootStringAlternateProvider() + alternateProvider, err := getConfigValue("alternateProvider", "string", "", "alternateProvider") if err != nil { return "", err } - if alternateProvider == "" { - if viper.IsSet("alternateProvider") { - alternateProvider = viper.GetString("alternateProvider") - } else { - alternateProvider = "" - log.Debug("alternate provider is not set, taking its nil value ", alternateProvider) - } - } - if !strings.HasPrefix(alternateProvider, "https") { + alternateProviderString := alternateProvider.(string) + if !strings.HasPrefix(alternateProviderString, "https") { log.Warn("You are not using a secure RPC URL. Switch to an https URL instead to be safe.") } - return alternateProvider, nil + return alternateProviderString, nil } //This function returns the multiplier func (*UtilsStruct) GetMultiplier() (float32, error) { - gasMultiplier, err := flagSetUtils.GetRootFloat32GasMultiplier() + const ( + MinMultiplier = 1.0 // Minimum multiplier value + MaxMultiplier = 3.0 // Maximum multiplier value + ) + + gasMultiplier, err := getConfigValue("gasmultiplier", "float32", core.DefaultGasMultiplier, "gasmultiplier") if err != nil { - return float32(core.DefaultGasMultiplier), err - } - if gasMultiplier == -1 { - if viper.IsSet("gasmultiplier") { - gasMultiplier = float32(viper.GetFloat64("gasmultiplier")) - } else { - gasMultiplier = float32(core.DefaultGasMultiplier) - log.Debug("GasMultiplier is not set, taking its default value ", gasMultiplier) - } + return core.DefaultGasMultiplier, err } - return gasMultiplier, nil + + multiplierFloat32 := gasMultiplier.(float32) + + // Validate multiplier range + if multiplierFloat32 < MinMultiplier || multiplierFloat32 > MaxMultiplier { + log.Infof("GasMultiplier %.2f is out of the valid range (%.1f-%.1f), using default value %.2f", multiplierFloat32, MinMultiplier, MaxMultiplier, core.DefaultGasMultiplier) + return core.DefaultGasMultiplier, nil + } + + return multiplierFloat32, nil } //This function returns the buffer percent func (*UtilsStruct) GetBufferPercent() (int32, error) { - bufferPercent, err := flagSetUtils.GetRootInt32Buffer() + const ( + MinBufferPercent = 0 + MaxBufferPercent = 30 + ) + + bufferPercent, err := getConfigValue("buffer", "int32", core.DefaultBufferPercent, "buffer") if err != nil { - return int32(core.DefaultBufferPercent), err - } - if bufferPercent == 0 { - if viper.IsSet("buffer") { - bufferPercent = viper.GetInt32("buffer") - } else { - bufferPercent = int32(core.DefaultBufferPercent) - log.Debug("BufferPercent is not set, taking its default value ", bufferPercent) - } + return core.DefaultBufferPercent, err + } + + bufferPercentInt32 := bufferPercent.(int32) + + // Check if bufferPercent is explicitly set and not within the valid range. + if bufferPercentInt32 < MinBufferPercent || bufferPercentInt32 > MaxBufferPercent { + log.Infof("BufferPercent %d is out of the valid range (%d-%d), using default value %d", bufferPercentInt32, MinBufferPercent, MaxBufferPercent, core.DefaultBufferPercent) + return core.DefaultBufferPercent, nil + } + + // If bufferPercent is 0, use the default value. + if bufferPercentInt32 == 0 { + log.Debugf("BufferPercent is unset or set to 0, using its default %d value", core.DefaultBufferPercent) + return core.DefaultBufferPercent, nil } - return bufferPercent, nil + + return bufferPercentInt32, nil } //This function returns the wait time func (*UtilsStruct) GetWaitTime() (int32, error) { - waitTime, err := flagSetUtils.GetRootInt32Wait() + const ( + MinWaitTime = 1 // Minimum wait time in seconds + MaxWaitTime = 30 // Maximum wait time in seconds + ) + + waitTime, err := getConfigValue("wait", "int32", core.DefaultWaitTime, "wait") if err != nil { - return int32(core.DefaultWaitTime), err - } - if waitTime == -1 { - if viper.IsSet("wait") { - waitTime = viper.GetInt32("wait") - } else { - waitTime = int32(core.DefaultWaitTime) - log.Debug("WaitTime is not set, taking its default value ", waitTime) - } + return core.DefaultWaitTime, err + } + + waitTimeInt32 := waitTime.(int32) + + // Validate waitTime range + if waitTimeInt32 < MinWaitTime || waitTimeInt32 > MaxWaitTime { + log.Infof("WaitTime %d is out of the valid range (%d-%d), using default value %d", waitTimeInt32, MinWaitTime, MaxWaitTime, core.DefaultWaitTime) + return core.DefaultWaitTime, nil } - return waitTime, nil + + return waitTimeInt32, nil } //This function returns the gas price func (*UtilsStruct) GetGasPrice() (int32, error) { - gasPrice, err := flagSetUtils.GetRootInt32GasPrice() + gasPrice, err := getConfigValue("gasprice", "int32", core.DefaultGasPrice, "gasprice") if err != nil { - return int32(core.DefaultGasPrice), err + return core.DefaultGasPrice, err } - if gasPrice == -1 { - if viper.IsSet("gasprice") { - gasPrice = viper.GetInt32("gasprice") - } else { - gasPrice = int32(core.DefaultGasPrice) - log.Debug("GasPrice is not set, taking its default value ", gasPrice) - } + gasPriceInt32 := gasPrice.(int32) + + // Validate gasPrice value + if gasPriceInt32 != 0 && gasPriceInt32 != 1 { + log.Infof("GasPrice %d is invalid, using default value %d", gasPriceInt32, core.DefaultGasPrice) + return core.DefaultGasPrice, nil } - return gasPrice, nil + + return gasPriceInt32, nil } //This function returns the log level func (*UtilsStruct) GetLogLevel() (string, error) { - logLevel, err := flagSetUtils.GetRootStringLogLevel() + logLevel, err := getConfigValue("logLevel", "string", core.DefaultLogLevel, "logLevel") if err != nil { return core.DefaultLogLevel, err } - if logLevel == "" { - if viper.IsSet("logLevel") { - logLevel = viper.GetString("logLevel") - } else { - logLevel = core.DefaultLogLevel - log.Debug("LogLevel is not set, taking its default value ", logLevel) - } - } - return logLevel, nil + return logLevel.(string), nil } //This function returns the gas limit func (*UtilsStruct) GetGasLimit() (float32, error) { - gasLimit, err := flagSetUtils.GetRootFloat32GasLimit() + //gasLimit in the config acts as a gasLimit multiplier + const ( + MinGasLimit = 1.0 // Minimum gas limit + MaxGasLimit = 3.0 // Maximum gas limit + ) + + gasLimit, err := getConfigValue("gasLimit", "float32", core.DefaultGasLimit, "gasLimit") if err != nil { - return float32(core.DefaultGasLimit), err - } - if gasLimit == -1 { - if viper.IsSet("gasLimit") { - gasLimit = float32(viper.GetFloat64("gasLimit")) - } else { - gasLimit = float32(core.DefaultGasLimit) - log.Debug("GasLimit is not set, taking its default value ", gasLimit) - } + return core.DefaultGasLimit, err } - return gasLimit, nil + + gasLimitFloat32 := gasLimit.(float32) + + // Validate gasLimit range + if gasLimitFloat32 < MinGasLimit || gasLimitFloat32 > MaxGasLimit { + log.Warnf("GasLimit %.2f is out of the suggested range (%.1f-%.1f), using default value %.2f", gasLimitFloat32, MinGasLimit, MaxGasLimit, core.DefaultGasLimit) + } + + return gasLimitFloat32, nil } //This function returns the gas limit to override func (*UtilsStruct) GetGasLimitOverride() (uint64, error) { - gasLimitOverride, err := flagSetUtils.GetRootUint64GasLimitOverride() + const ( + MinGasLimitOverride = 10000000 // Minimum gas limit override + MaxGasLimitOverride = 50000000 // Maximum gas limit override + ) + + gasLimitOverride, err := getConfigValue("gasLimitOverride", "uint64", core.DefaultGasLimitOverride, "gasLimitOverride") if err != nil { - return uint64(core.DefaultGasLimitOverride), err - } - if gasLimitOverride == 0 { - if viper.IsSet("gasLimitOverride") { - gasLimitOverride = viper.GetUint64("gasLimitOverride") - } else { - gasLimitOverride = uint64(core.DefaultGasLimitOverride) - log.Debug("GasLimitOverride is not set, taking its default value ", gasLimitOverride) - } + return core.DefaultGasLimitOverride, err + } + + gasLimitOverrideUint64 := gasLimitOverride.(uint64) + + // Validate gasLimitOverride range + if gasLimitOverrideUint64 < MinGasLimitOverride || gasLimitOverrideUint64 > MaxGasLimitOverride { + log.Infof("GasLimitOverride %d is out of the valid range (%d-%d), using default value %d", gasLimitOverrideUint64, MinGasLimitOverride, MaxGasLimitOverride, core.DefaultGasLimitOverride) + return core.DefaultGasLimitOverride, nil } - return gasLimitOverride, nil + + return gasLimitOverrideUint64, nil } //This function returns the RPC timeout func (*UtilsStruct) GetRPCTimeout() (int64, error) { - rpcTimeout, err := flagSetUtils.GetRootInt64RPCTimeout() + const ( + MinRPCTimeout = 10 // Minimum RPC timeout in seconds + MaxRPCTimeout = 60 // Maximum RPC timeout in seconds + ) + + rpcTimeout, err := getConfigValue("rpcTimeout", "int64", core.DefaultRPCTimeout, "rpcTimeout") if err != nil { - return int64(core.DefaultRPCTimeout), err - } - if rpcTimeout == 0 { - if viper.IsSet("rpcTimeout") { - rpcTimeout = viper.GetInt64("rpcTimeout") - } else { - rpcTimeout = int64(core.DefaultRPCTimeout) - log.Debug("RPCTimeout is not set, taking its default value ", rpcTimeout) - } + return core.DefaultRPCTimeout, err } - return rpcTimeout, nil + + rpcTimeoutInt64 := rpcTimeout.(int64) + + // Validate rpcTimeout range + if rpcTimeoutInt64 < MinRPCTimeout || rpcTimeoutInt64 > MaxRPCTimeout { + log.Infof("RPCTimeout %d is out of the valid range (%d-%d), using default value %d", rpcTimeoutInt64, MinRPCTimeout, MaxRPCTimeout, core.DefaultRPCTimeout) + return core.DefaultRPCTimeout, nil + } + + return rpcTimeoutInt64, nil } func (*UtilsStruct) GetHTTPTimeout() (int64, error) { - httpTimeout, err := flagSetUtils.GetRootInt64HTTPTimeout() + const ( + MinHTTPTimeout = 10 // Minimum HTTP timeout in seconds + MaxHTTPTimeout = 60 // Maximum HTTP timeout in seconds + ) + + httpTimeout, err := getConfigValue("httpTimeout", "int64", core.DefaultHTTPTimeout, "httpTimeout") if err != nil { - return int64(core.DefaultHTTPTimeout), err - } - if httpTimeout == 0 { - if viper.IsSet("httpTimeout") { - httpTimeout = viper.GetInt64("httpTimeout") - } else { - httpTimeout = int64(core.DefaultRPCTimeout) - log.Debug("HTTPTimeout is not set, taking its default value ", httpTimeout) - } + return core.DefaultHTTPTimeout, err + } + + httpTimeoutInt64 := httpTimeout.(int64) + + // Validate httpTimeout range + if httpTimeoutInt64 < MinHTTPTimeout || httpTimeoutInt64 > MaxHTTPTimeout { + log.Infof("HTTPTimeout %d is out of the valid range (%d-%d), using default value %d", httpTimeoutInt64, MinHTTPTimeout, MaxHTTPTimeout, core.DefaultHTTPTimeout) + return core.DefaultHTTPTimeout, nil } - return httpTimeout, nil + + return httpTimeoutInt64, nil } func (*UtilsStruct) GetLogFileMaxSize() (int, error) { - logFileMaxSize, err := flagSetUtils.GetRootIntLogFileMaxSize() + logFileMaxSize, err := getConfigValue("logFileMaxSize", "int", core.DefaultLogFileMaxSize, "logFileMaxSize") if err != nil { return core.DefaultLogFileMaxSize, err } - if logFileMaxSize == 0 { - if viper.IsSet("logFileMaxSize") { - logFileMaxSize = viper.GetInt("logFileMaxSize") - } else { - logFileMaxSize = core.DefaultLogFileMaxSize - log.Debug("logFileMaxSize is not set, taking its default value ", logFileMaxSize) - } - } - return logFileMaxSize, nil + return logFileMaxSize.(int), nil } func (*UtilsStruct) GetLogFileMaxBackups() (int, error) { - logFileMaxBackups, err := flagSetUtils.GetRootIntLogFileMaxBackups() + logFileMaxBackups, err := getConfigValue("logFileMaxBackups", "int", core.DefaultLogFileMaxBackups, "logFileMaxBackups") if err != nil { return core.DefaultLogFileMaxBackups, err } - if logFileMaxBackups == 0 { - if viper.IsSet("logFileMaxBackups") { - logFileMaxBackups = viper.GetInt("logFileMaxBackups") - } else { - logFileMaxBackups = core.DefaultLogFileMaxBackups - log.Debug("logFileMaxBackups is not set, taking its default value ", logFileMaxBackups) - } - } - return logFileMaxBackups, nil + return logFileMaxBackups.(int), nil } func (*UtilsStruct) GetLogFileMaxAge() (int, error) { - logFileMaxAge, err := flagSetUtils.GetRootIntLogFileMaxAge() + logFileMaxAge, err := getConfigValue("logFileMaxAge", "int", core.DefaultLogFileMaxAge, "logFileMaxAge") if err != nil { return core.DefaultLogFileMaxAge, err } - if logFileMaxAge == 0 { - if viper.IsSet("logFileMaxAge") { - logFileMaxAge = viper.GetInt("logFileMaxAge") - } else { - logFileMaxAge = core.DefaultLogFileMaxAge - log.Debug("logFileMaxAge is not set, taking its default value ", logFileMaxAge) - } - } - return logFileMaxAge, nil + return logFileMaxAge.(int), nil } //This function sets the log level @@ -356,18 +405,7 @@ func setLogLevel(config types.Configurations) { log.SetLevel(logrus.DebugLevel) } - log.Debug("Config details: ") - log.Debugf("Provider: %s", config.Provider) - log.Debugf("Alternate Provider: %s", config.AlternateProvider) - log.Debugf("Gas Multiplier: %.2f", config.GasMultiplier) - log.Debugf("Buffer Percent: %d", config.BufferPercent) - log.Debugf("Wait Time: %d", config.WaitTime) - log.Debugf("Gas Price: %d", config.GasPrice) - log.Debugf("Log Level: %s", config.LogLevel) - log.Debugf("Gas Limit: %.2f", config.GasLimitMultiplier) - log.Debugf("Gas Limit Override: %d", config.GasLimitOverride) - log.Debugf("RPC Timeout: %d", config.RPCTimeout) - log.Debugf("HTTP Timeout: %d", config.HTTPTimeout) + log.Debugf("Config details: %+v", config) if razorUtils.IsFlagPassed("logFile") { log.Debugf("Log File Max Size: %d MB", config.LogFileMaxSize) diff --git a/cmd/config-utils_test.go b/cmd/config-utils_test.go index c6a2f0618..da9e80f42 100644 --- a/cmd/config-utils_test.go +++ b/cmd/config-utils_test.go @@ -2,13 +2,38 @@ package cmd import ( "errors" + "github.com/spf13/viper" "github.com/stretchr/testify/mock" + "os" + "path/filepath" "razor/cmd/mocks" + "razor/core" "razor/core/types" "reflect" + "strings" "testing" ) +var tempConfigPath = "test_config.yaml" + +func createTestConfig(t *testing.T, viperKey string, value interface{}) { + + // Set some values + viper.Set(viperKey, value) + + // Write the temporary config + if err := viper.WriteConfigAs(tempConfigPath); err != nil { + t.Fatalf("Failed to write temp config: %s", err) + } + + viper.SetConfigName(strings.TrimSuffix(tempConfigPath, filepath.Ext(tempConfigPath))) + viper.AddConfigPath(".") +} + +func removeTestConfig(path string) { + os.RemoveAll(path) +} + func TestGetConfigData(t *testing.T) { nilConfig := types.Configurations{ Provider: "", @@ -219,45 +244,74 @@ func TestGetConfigData(t *testing.T) { func TestGetBufferPercent(t *testing.T) { type args struct { - bufferPercent int32 - bufferPercentErr error + isFlagSet bool + bufferPercent int32 + bufferPercentErr error + bufferInTestConfig int32 } tests := []struct { - name string - args args - want int32 - wantErr error + name string + useDummyConfigFile bool + args args + want int32 + wantErr error }{ { - name: "Test 1: When getBufferPercent function executes successfully", + name: "Test 1: When buffer percent is fetched from root flag", args: args{ - bufferPercent: 20, + isFlagSet: true, + bufferPercent: 5, }, - want: 20, + want: 5, wantErr: nil, }, { - name: "Test 2: When bufferPercent is 0", + name: "Test 2: When there is an error in fetching buffer from root flag", args: args{ - bufferPercent: 0, + isFlagSet: true, + bufferPercentErr: errors.New("buffer percent error"), }, - want: 20, + want: core.DefaultBufferPercent, + wantErr: errors.New("buffer percent error"), + }, + { + name: "Test 3: When buffer value is fetched from config", + useDummyConfigFile: true, + args: args{ + bufferInTestConfig: 1, + }, + want: 1, + wantErr: nil, + }, + { + name: "Test 4: When buffer is not passed in root nor set in config", + want: core.DefaultBufferPercent, wantErr: nil, }, { - name: "Test 3: When there is an error in getting bufferPercent", + name: "Test 5: When buffer value is out of a valid range", + useDummyConfigFile: true, args: args{ - bufferPercentErr: errors.New("bufferPercent error"), + bufferInTestConfig: 40, }, - want: 20, - wantErr: errors.New("bufferPercent error"), + want: core.DefaultBufferPercent, + wantErr: nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + viper.Reset() // Reset viper state + + if tt.useDummyConfigFile { + createTestConfig(t, "buffer", tt.args.bufferInTestConfig) + defer removeTestConfig(tempConfigPath) + } + SetUpMockInterfaces() - flagSetMock.On("GetRootInt32Buffer").Return(tt.args.bufferPercent, tt.args.bufferPercentErr) + flagSetMock.On("FetchRootFlagInput", mock.Anything, mock.Anything).Return(tt.args.bufferPercent, tt.args.bufferPercentErr) + flagSetMock.On("Changed", mock.Anything, mock.Anything).Return(tt.args.isFlagSet) + utils := &UtilsStruct{} got, err := utils.GetBufferPercent() if got != tt.want { @@ -274,51 +328,84 @@ func TestGetBufferPercent(t *testing.T) { } }) } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + }) + } } func TestGetGasLimit(t *testing.T) { type args struct { - gasLimit float32 - gasLimitErr error + isFlagSet bool + gasLimit float32 + gasLimitErr error + gasLimitInTestConfig float32 } tests := []struct { - name string - args args - want float32 - wantErr error + name string + useDummyConfigFile bool + args args + want float32 + wantErr error }{ { - name: "Test 1: When getGasLimit function executes successfully", + name: "Test 1: When gasLimit is fetched from root flag", args: args{ - gasLimit: 4, + isFlagSet: true, + gasLimit: 2, }, - want: 4, + want: 2, wantErr: nil, }, { - name: "Test 2: When gasLimit is -1", + name: "Test 2: When there is an error in fetching gasLimit from root flag", args: args{ - gasLimit: -1, + isFlagSet: true, + gasLimitErr: errors.New("gasLimit error"), }, - want: 2, + want: core.DefaultGasLimit, + wantErr: errors.New("gasLimit error"), + }, + { + name: "Test 3: When gas value is fetched from config", + useDummyConfigFile: true, + args: args{ + gasLimitInTestConfig: 2.5, + }, + want: 2.5, + wantErr: nil, + }, + { + name: "Test 4: When gasLimit is not passed in root nor set in config", + want: core.DefaultGasLimit, wantErr: nil, }, { - name: "Test 3: When there is an error in getting gasLimit", + name: "Test 5: When gas limit value is out of valid range", + useDummyConfigFile: true, args: args{ - gasLimitErr: errors.New("gasLimit error"), + gasLimitInTestConfig: 3.5, }, - want: 2, - wantErr: errors.New("gasLimit error"), + want: 3.5, + wantErr: nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + viper.Reset() // Reset viper state + + if tt.useDummyConfigFile { + createTestConfig(t, "gasLimit", tt.args.gasLimitInTestConfig) + defer removeTestConfig(tempConfigPath) + } + SetUpMockInterfaces() - flagSetMock.On("GetRootFloat32GasLimit").Return(tt.args.gasLimit, tt.args.gasLimitErr) - utils := &UtilsStruct{} + flagSetMock.On("FetchRootFlagInput", mock.Anything, mock.Anything).Return(tt.args.gasLimit, tt.args.gasLimitErr) + flagSetMock.On("Changed", mock.Anything, mock.Anything).Return(tt.args.isFlagSet) + utils := &UtilsStruct{} got, err := utils.GetGasLimit() if got != tt.want { t.Errorf("getGasLimit() got = %v, want %v", got, tt.want) @@ -338,48 +425,76 @@ func TestGetGasLimit(t *testing.T) { func TestGetGasLimitOverride(t *testing.T) { type args struct { - gasLimitOverride uint64 - gasLimitOverrideErr error + isFlagSet bool + gasLimitOverride uint64 + gasLimitOverrideErr error + gasLimitOverrideInConfig uint64 } tests := []struct { - name string - args args - want uint64 - wantErr error + name string + useDummyConfigFile bool + args args + want uint64 + wantErr error }{ { - name: "Test 1: When getGasLimitOverride function executes successfully", + name: "Test 1: When gasLimitOverride is fetched from root flag", args: args{ - gasLimitOverride: 5000000, + isFlagSet: true, + gasLimitOverride: 40000000, }, - want: 5000000, + want: 40000000, wantErr: nil, }, { - name: "Test 2: When gasLimitOverride is 0", + name: "Test 2: When there is an error in fetching gasLimitOverride from root flag", + args: args{ + isFlagSet: true, + gasLimitOverrideErr: errors.New("gasLimitOverride error"), + }, + want: core.DefaultGasLimitOverride, + wantErr: errors.New("gasLimitOverride error"), + }, + { + name: "Test 3: When gasLimitOverride is fetched from config", + useDummyConfigFile: true, args: args{ - gasLimitOverride: 0, + gasLimitOverrideInConfig: 30000000, }, - want: 50000000, + want: 30000000, wantErr: nil, }, { - name: "Test 3: When there is an error in getting gasLimitOverride", + name: "Test 4: When gasLimitOverride is not passed in root nor set in config", + want: core.DefaultGasLimitOverride, + wantErr: nil, + }, + { + name: "Test 3: When gasLimitOverride is fetched from config", + useDummyConfigFile: true, args: args{ - gasLimitOverrideErr: errors.New("gasLimitOverride error"), + gasLimitOverrideInConfig: 60000000, }, - want: 50000000, - wantErr: errors.New("gasLimitOverride error"), + want: core.DefaultGasLimitOverride, + wantErr: nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + viper.Reset() // Reset viper state + + if tt.useDummyConfigFile { + createTestConfig(t, "gasLimitOverride", tt.args.gasLimitOverrideInConfig) + defer removeTestConfig(tempConfigPath) + } + flagSetUtilsMock := new(mocks.FlagSetInterface) flagSetUtils = flagSetUtilsMock - flagSetUtilsMock.On("GetRootUint64GasLimitOverride").Return(tt.args.gasLimitOverride, tt.args.gasLimitOverrideErr) - utils := &UtilsStruct{} + flagSetUtilsMock.On("FetchRootFlagInput", mock.Anything, mock.Anything).Return(tt.args.gasLimitOverride, tt.args.gasLimitOverrideErr) + flagSetUtilsMock.On("Changed", mock.Anything, mock.Anything).Return(tt.args.isFlagSet) + utils := &UtilsStruct{} got, err := utils.GetGasLimitOverride() if got != tt.want { t.Errorf("getGasLimitOverride() got = %v, want %v", got, tt.want) @@ -399,48 +514,76 @@ func TestGetGasLimitOverride(t *testing.T) { func TestGetGasPrice(t *testing.T) { type args struct { - gasPrice int32 - gasPriceErr error + isFlagSet bool + gasPrice int32 + gasPriceErr error + gasPriceInTestConfig int32 } tests := []struct { - name string - args args - want int32 - wantErr error + name string + useDummyConfigFile bool + args args + want int32 + wantErr error }{ { - name: "Test 1: When getGasPrice function executes successfully", + name: "Test 1: When gasPrice is fetched from root flag", args: args{ - gasPrice: 1, + isFlagSet: true, + gasPrice: 1, }, want: 1, wantErr: nil, }, { - name: "Test 2: When gasPrice is -1", + name: "Test 2: When there is an error in fetching gasPrice from root flag", args: args{ - gasPrice: -1, + isFlagSet: true, + gasPriceErr: errors.New("gasPrice error"), }, - want: 1, + want: core.DefaultGasPrice, + wantErr: errors.New("gasPrice error"), + }, + { + name: "Test 3: When gasPrice value is fetched from config", + useDummyConfigFile: true, + args: args{ + gasPriceInTestConfig: 0, + }, + want: 0, wantErr: nil, }, { - name: "Test 3: When there is an error in getting gasPrice", + name: "Test 4: When gasPrice is not passed in root nor set in config", + want: core.DefaultGasPrice, + wantErr: nil, + }, + { + name: "Test 5: When gasPrice is out of valid range", + useDummyConfigFile: true, args: args{ - gasPriceErr: errors.New("gasPrice error"), + gasPriceInTestConfig: 3, }, - want: 1, - wantErr: errors.New("gasPrice error"), + want: core.DefaultGasPrice, + wantErr: nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + viper.Reset() // Reset viper state + + if tt.useDummyConfigFile { + createTestConfig(t, "gasprice", tt.args.gasPriceInTestConfig) + defer removeTestConfig(tempConfigPath) + } + flagSetUtilsMock := new(mocks.FlagSetInterface) flagSetUtils = flagSetUtilsMock - flagSetUtilsMock.On("GetRootInt32GasPrice").Return(tt.args.gasPrice, tt.args.gasPriceErr) - utils := &UtilsStruct{} + flagSetUtilsMock.On("FetchRootFlagInput", mock.Anything, mock.Anything).Return(tt.args.gasPrice, tt.args.gasPriceErr) + flagSetUtilsMock.On("Changed", mock.Anything, mock.Anything).Return(tt.args.isFlagSet) + utils := &UtilsStruct{} got, err := utils.GetGasPrice() if got != tt.want { t.Errorf("getGasPrice() got = %v, want %v", got, tt.want) @@ -460,47 +603,66 @@ func TestGetGasPrice(t *testing.T) { func TestGetLogLevel(t *testing.T) { type args struct { - logLevel string - logLevelErr error + isFlagSet bool + logLevel string + logLevelErr error + logLevelInTestConfig string } tests := []struct { - name string - args args - want string - wantErr error + name string + useDummyConfigFile bool + args args + want string + wantErr error }{ { - name: "Test 1: When getLogLevel function executes successfully", + name: "Test 1: When logLevel is fetched from root flag", args: args{ - logLevel: "debug", + isFlagSet: true, + logLevel: "debug", }, want: "debug", wantErr: nil, }, { - name: "Test 2: When logLevel is nil", + name: "Test 2: When there is an error in fetching logLevel from root flag", args: args{ - logLevel: "", + isFlagSet: true, + logLevelErr: errors.New("logLevel error"), }, - want: "", - wantErr: nil, + want: core.DefaultLogLevel, + wantErr: errors.New("logLevel error"), }, { - name: "Test 3: When there is an error in getting logLevel", + name: "Test 3: When logLevel value is fetched from config", + useDummyConfigFile: true, args: args{ - logLevelErr: errors.New("logLevel error"), + logLevelInTestConfig: "info", }, - want: "", - wantErr: errors.New("logLevel error"), + want: "info", + wantErr: nil, + }, + { + name: "Test 4: When logLevel is not passed in root nor set in config", + want: core.DefaultLogLevel, + wantErr: nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + viper.Reset() // Reset viper state + + if tt.useDummyConfigFile { + createTestConfig(t, "logLevel", tt.args.logLevelInTestConfig) + defer removeTestConfig(tempConfigPath) + } + SetUpMockInterfaces() - flagSetMock.On("GetRootStringLogLevel").Return(tt.args.logLevel, tt.args.logLevelErr) - utils := &UtilsStruct{} + flagSetMock.On("FetchRootFlagInput", mock.Anything, mock.Anything).Return(tt.args.logLevel, tt.args.logLevelErr) + flagSetMock.On("Changed", mock.Anything, mock.Anything).Return(tt.args.isFlagSet) + utils := &UtilsStruct{} got, err := utils.GetLogLevel() if got != tt.want { t.Errorf("getLogLevel() got = %v, want %v", got, tt.want) @@ -520,47 +682,75 @@ func TestGetLogLevel(t *testing.T) { func TestGetMultiplier(t *testing.T) { type args struct { - gasMultiplier float32 - gasMultiplierErr error + isFlagSet bool + gasMultiplier float32 + gasMultiplierErr error + gasMultiplierInTestConfig float32 } tests := []struct { - name string - args args - want float32 - wantErr error + name string + useDummyConfigFile bool + args args + want float32 + wantErr error }{ { - name: "Test 1: When getMultiplier function executes successfully", + name: "Test 1: When gasMultiplier is fetched from root flag", args: args{ + isFlagSet: true, gasMultiplier: 2, }, want: 2, wantErr: nil, }, { - name: "Test 2: When gasMultiplier is -1", + name: "Test 2: When there is an error in fetching gasMultiplier from root flag", args: args{ - gasMultiplier: -1, + isFlagSet: true, + gasMultiplierErr: errors.New("gasMultiplier error"), }, - want: 1, + want: core.DefaultGasMultiplier, + wantErr: errors.New("gasMultiplier error"), + }, + { + name: "Test 3: When gasMultiplier value is fetched from config", + useDummyConfigFile: true, + args: args{ + gasMultiplierInTestConfig: 3, + }, + want: 3, wantErr: nil, }, { - name: "Test 3: When there is an error in getting gasMultiplier", + name: "Test 4: When gasMultiplier is not passed in root nor set in config", + want: core.DefaultGasMultiplier, + wantErr: nil, + }, + { + name: "Test 5: When gasMultiplier is out of a valid range", + useDummyConfigFile: true, args: args{ - gasMultiplierErr: errors.New("gasMultiplier error"), + gasMultiplierInTestConfig: 4, }, - want: 1, - wantErr: errors.New("gasMultiplier error"), + want: core.DefaultGasMultiplier, + wantErr: nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + viper.Reset() // Reset viper state + + if tt.useDummyConfigFile { + createTestConfig(t, "gasmultiplier", tt.args.gasMultiplierInTestConfig) + defer removeTestConfig(tempConfigPath) + } + SetUpMockInterfaces() - flagSetMock.On("GetRootFloat32GasMultiplier").Return(tt.args.gasMultiplier, tt.args.gasMultiplierErr) - utils := &UtilsStruct{} + flagSetMock.On("FetchRootFlagInput", mock.Anything, mock.Anything).Return(tt.args.gasMultiplier, tt.args.gasMultiplierErr) + flagSetMock.On("Changed", mock.Anything, mock.Anything).Return(tt.args.isFlagSet) + utils := &UtilsStruct{} got, err := utils.GetMultiplier() if got != tt.want { t.Errorf("getMultiplier() got = %v, want %v", got, tt.want) @@ -580,55 +770,67 @@ func TestGetMultiplier(t *testing.T) { func TestGetProvider(t *testing.T) { type args struct { - provider string - providerErr error + provider string + providerErr error + isFlagSet bool + providerInTestConfig string } tests := []struct { - name string - args args - want string - wantErr error + name string + args args + useDummyConfigFile bool + want string + wantErr error }{ { - name: "Test 1: When getProvider function execute successfully", + name: "Test 1: When provider is fetched from root flag", args: args{ - provider: "https://polygon-mumbai.g.alchemy.com/v2/-Re1lE3oDIVTWchuKMfRIECn0I", + provider: "https://polygon-mumbai.g.alchemy.com/v2/-Re1lE3oDIVTWchuKMfRIECn0I", + isFlagSet: true, }, want: "https://polygon-mumbai.g.alchemy.com/v2/-Re1lE3oDIVTWchuKMfRIECn0I", wantErr: nil, }, { - name: "Test 2: When provider has prefix https", - args: args{ - provider: "127.0.0.1:8545", - }, - want: "127.0.0.1:8545", - wantErr: nil, - }, - { - name: "Test 3: When there is an error in getting provider", + name: "Test 2: When there is an error in fetching provider from root flag", args: args{ providerErr: errors.New("provider error"), + isFlagSet: true, }, want: "", wantErr: errors.New("provider error"), }, { - name: "Test 4: When provider is nil", + name: "Test 3: When provider value is fetched from config", + useDummyConfigFile: true, args: args{ - provider: "", + providerInTestConfig: "https://config-provider-url.com", }, + want: "https://config-provider-url.com", + wantErr: nil, + }, + { + name: "Test 4: When provider is neither passed in root nor set in config", + args: args{}, want: "", wantErr: errors.New("provider is not set"), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + viper.Reset() // Reset viper state + + if tt.useDummyConfigFile { + createTestConfig(t, "provider", tt.args.providerInTestConfig) + defer removeTestConfig(tempConfigPath) + } + SetUpMockInterfaces() - flagSetMock.On("GetRootStringProvider").Return(tt.args.provider, tt.args.providerErr) - utils := &UtilsStruct{} + flagSetMock.On("FetchRootFlagInput", mock.Anything, mock.Anything).Return(tt.args.provider, tt.args.providerErr) + flagSetMock.On("Changed", mock.Anything, "provider").Return(tt.args.isFlagSet) + utils := &UtilsStruct{} got, err := utils.GetProvider() if got != tt.want { t.Errorf("getProvider() got = %v, want %v", got, tt.want) @@ -648,55 +850,75 @@ func TestGetProvider(t *testing.T) { func TestGetAlternateProvider(t *testing.T) { type args struct { - alternateProvider string - alternateProviderErr error + isFlagSet bool + alternateProvider string + alternateProviderErr error + alternateProviderInConfig string } tests := []struct { - name string - args args - want string - wantErr error + name string + useDummyConfigFile bool + args args + want string + wantErr error }{ { - name: "Test 1: When getAlternateProvider function execute successfully", + name: "Test 1: When alternateProvider is fetched from root flag", args: args{ + isFlagSet: true, alternateProvider: "https://polygon-mumbai.g.alchemy.com/v2/-Re1lE3oDIVTWchuKMfRIECn0I", }, want: "https://polygon-mumbai.g.alchemy.com/v2/-Re1lE3oDIVTWchuKMfRIECn0I", wantErr: nil, }, { - name: "Test 2: When alternate provider has prefix https", + name: "Test 2: When alternateProvider from root flag has prefix https", args: args{ + isFlagSet: true, alternateProvider: "127.0.0.1:8545", }, want: "127.0.0.1:8545", wantErr: nil, }, { - name: "Test 3: When there is an error in getting alternate provider", + name: "Test 3: When there is an error in fetching alternateProvider from root flag", args: args{ + isFlagSet: true, alternateProviderErr: errors.New("alternateProvider error"), }, want: "", wantErr: errors.New("alternateProvider error"), }, { - name: "Test 4: When alternate provider is nil", + name: "Test 4: When alternateProvider value is fetched from config", + useDummyConfigFile: true, args: args{ - alternateProvider: "", + alternateProviderInConfig: "https://some-config-provider.com", }, + want: "https://some-config-provider.com", + wantErr: nil, + }, + { + name: "Test 5: When alternateProvider is not passed in root nor set in config", want: "", wantErr: nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + viper.Reset() // Reset viper state + + if tt.useDummyConfigFile { + createTestConfig(t, "alternateProvider", tt.args.alternateProviderInConfig) + defer removeTestConfig(tempConfigPath) + } + SetUpMockInterfaces() - flagSetMock.On("GetRootStringAlternateProvider").Return(tt.args.alternateProvider, tt.args.alternateProviderErr) - utils := &UtilsStruct{} + flagSetMock.On("FetchRootFlagInput", mock.Anything, mock.Anything).Return(tt.args.alternateProvider, tt.args.alternateProviderErr) + flagSetMock.On("Changed", mock.Anything, mock.Anything).Return(tt.args.isFlagSet) + utils := &UtilsStruct{} got, err := utils.GetAlternateProvider() if got != tt.want { t.Errorf("getAlternateProvider() got = %v, want %v", got, tt.want) @@ -714,177 +936,264 @@ func TestGetAlternateProvider(t *testing.T) { } } -func TestGetWaitTime(t *testing.T) { +func TestGetRPCTimeout(t *testing.T) { type args struct { - waitTime int32 - waitTimeErr error + isFlagSet bool + rpcTimeout int64 + rpcTimeoutErr error + rpcTimeoutInTestConfig int64 } tests := []struct { - name string - args args - want int32 - wantErr error + name string + useDummyConfigFile bool + args args + want int64 + wantErr error }{ { - name: "Test 1: When getWaitTime function executes successfully", + name: "Test 1: When rpcTimeout is fetched from root flag", args: args{ - waitTime: 4, + isFlagSet: true, + rpcTimeout: 12, }, - want: 4, + want: 12, wantErr: nil, }, { - name: "Test 2: When waitTime is -1", + name: "Test 2: When there is an error in fetching rpcTimeout from root flag", args: args{ - waitTime: -1, + isFlagSet: true, + rpcTimeoutErr: errors.New("rpcTimeout error"), }, - want: 1, + want: core.DefaultRPCTimeout, + wantErr: errors.New("rpcTimeout error"), + }, + { + name: "Test 3: When rpcTimeout value is fetched from config", + useDummyConfigFile: true, + args: args{ + rpcTimeoutInTestConfig: 20, + }, + want: 20, wantErr: nil, }, { - name: "Test 3: When there is an error in getting waitTime", + name: "Test 4: When rpcTimeout is not passed in root nor set in config", + want: core.DefaultRPCTimeout, + wantErr: nil, + }, + { + name: "Test 5: When rpcTimeout value is out of a valid range", + useDummyConfigFile: true, args: args{ - waitTimeErr: errors.New("waitTime error"), + rpcTimeoutInTestConfig: 70, }, - want: 1, - wantErr: errors.New("waitTime error"), + want: core.DefaultRPCTimeout, + wantErr: nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + viper.Reset() // Reset viper state + + if tt.useDummyConfigFile { + createTestConfig(t, "rpcTimeout", tt.args.rpcTimeoutInTestConfig) + defer removeTestConfig(tempConfigPath) + } + SetUpMockInterfaces() - flagSetMock.On("GetRootInt32Wait").Return(tt.args.waitTime, tt.args.waitTimeErr) + flagSetMock.On("FetchRootFlagInput", mock.Anything, mock.Anything).Return(tt.args.rpcTimeout, tt.args.rpcTimeoutErr) + flagSetMock.On("Changed", mock.Anything, mock.Anything).Return(tt.args.isFlagSet) + utils := &UtilsStruct{} - got, err := utils.GetWaitTime() + got, err := utils.GetRPCTimeout() if got != tt.want { - t.Errorf("getWaitTime() got = %v, want %v", got, tt.want) + t.Errorf("getRPCTimeout() got = %v, want %v", got, tt.want) } if err == nil || tt.wantErr == nil { if err != tt.wantErr { - t.Errorf("Error for getWaitTime function, got = %v, want = %v", err, tt.wantErr) + t.Errorf("Error for getRPCTimeout function, got = %v, want = %v", err, tt.wantErr) } } else { if err.Error() != tt.wantErr.Error() { - t.Errorf("Error for getWaitTime function, got = %v, want = %v", err, tt.wantErr) + t.Errorf("Error for getRPCTimeout function, got = %v, want = %v", err, tt.wantErr) } } }) } } -func TestGetRPCTimeout(t *testing.T) { +func TestGetHTTPTimeout(t *testing.T) { type args struct { - rpcTimeout int64 - rpcTimeoutErr error + isFlagSet bool + httpTimeout int64 + httpTimeoutErr error + httpTimeoutInTestConfig int64 } tests := []struct { - name string - args args - want int64 - wantErr error + name string + useDummyConfigFile bool + args args + want int64 + wantErr error }{ { - name: "Test 1: When getRPCTimeout function executes successfully", + name: "Test 1: When httpTimeout is fetched from root flag", args: args{ - rpcTimeout: 12, + isFlagSet: true, + httpTimeout: 12, }, want: 12, wantErr: nil, }, { - name: "Test 2: When rpcTimeout is 0", + name: "Test 2: When there is an error in fetching httpTimeout from root flag", + args: args{ + isFlagSet: true, + httpTimeoutErr: errors.New("httpTimeout error"), + }, + want: core.DefaultHTTPTimeout, + wantErr: errors.New("httpTimeout error"), + }, + { + name: "Test 3: When httpTimeout value is fetched from config", + useDummyConfigFile: true, args: args{ - rpcTimeout: 0, + httpTimeoutInTestConfig: 20, }, - want: 10, + want: 20, wantErr: nil, }, { - name: "Test 3: When there is an error in getting rpcTimeout", + name: "Test 4: When httpTimeout is not passed in root nor set in config", + want: core.DefaultHTTPTimeout, + wantErr: nil, + }, + { + name: "Test 5: When httpTimeout is out of valid range", + useDummyConfigFile: true, args: args{ - rpcTimeoutErr: errors.New("rpcTimeout error"), + httpTimeoutInTestConfig: 70, }, - want: 10, - wantErr: errors.New("rpcTimeout error"), + want: core.DefaultHTTPTimeout, + wantErr: nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + viper.Reset() // Reset viper state + + if tt.useDummyConfigFile { + createTestConfig(t, "httpTimeout", tt.args.httpTimeoutInTestConfig) + defer removeTestConfig(tempConfigPath) + } + SetUpMockInterfaces() - flagSetMock.On("GetRootInt64RPCTimeout").Return(tt.args.rpcTimeout, tt.args.rpcTimeoutErr) + flagSetMock.On("FetchRootFlagInput", mock.Anything, mock.Anything).Return(tt.args.httpTimeout, tt.args.httpTimeoutErr) + flagSetMock.On("Changed", mock.Anything, mock.Anything).Return(tt.args.isFlagSet) + utils := &UtilsStruct{} - got, err := utils.GetRPCTimeout() + got, err := utils.GetHTTPTimeout() if got != tt.want { - t.Errorf("getRPCTimeout() got = %v, want %v", got, tt.want) + t.Errorf("getHTTPTimeout() got = %v, want %v", got, tt.want) } if err == nil || tt.wantErr == nil { if err != tt.wantErr { - t.Errorf("Error for getRPCTimeout function, got = %v, want = %v", err, tt.wantErr) + t.Errorf("Error for getHTTPTimeout function, got = %v, want = %v", err, tt.wantErr) } } else { if err.Error() != tt.wantErr.Error() { - t.Errorf("Error for getRPCTimeout function, got = %v, want = %v", err, tt.wantErr) + t.Errorf("Error for getHTTPTimeout function, got = %v, want = %v", err, tt.wantErr) } } }) } } -func TestGetHTTPTimeout(t *testing.T) { +func TestGetWaitTime(t *testing.T) { type args struct { - httpTimeout int64 - httpTimeoutErr error + isFlagSet bool + waitTime int32 + waitTimeErr error + waitInTestConfig int32 } tests := []struct { - name string - args args - want int64 - wantErr error + name string + useDummyConfigFile bool + args args + want int32 + wantErr error }{ { - name: "Test 1: When getHTTPTimeout function executes successfully", + name: "Test 1: When wait time is fetched from root flag", args: args{ - httpTimeout: 12, + isFlagSet: true, + waitTime: 2, }, - want: 12, + want: 2, wantErr: nil, }, { - name: "Test 2: When httpTimeout is 0", + name: "Test 2: When there is an error in fetching wait time from root flag", args: args{ - httpTimeout: 0, + isFlagSet: true, + waitTimeErr: errors.New("wait time error"), }, - want: 10, + want: core.DefaultWaitTime, + wantErr: errors.New("wait time error"), + }, + { + name: "Test 3: When wait time value is fetched from config", + useDummyConfigFile: true, + args: args{ + waitInTestConfig: 3, + }, + want: 3, + wantErr: nil, + }, + { + name: "Test 4: When wait time is not passed in root nor set in config", + want: core.DefaultWaitTime, wantErr: nil, }, { - name: "Test 3: When there is an error in getting httpTimeout", + name: "Test 5: When wait time value is out of valid range", + useDummyConfigFile: true, args: args{ - httpTimeoutErr: errors.New("httpTimeout error"), + waitInTestConfig: 40, }, - want: 10, - wantErr: errors.New("httpTimeout error"), + want: core.DefaultWaitTime, + wantErr: nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + viper.Reset() // Reset viper state + + if tt.useDummyConfigFile { + createTestConfig(t, "wait", tt.args.waitInTestConfig) + defer removeTestConfig(tempConfigPath) + } + SetUpMockInterfaces() - flagSetMock.On("GetRootInt64HTTPTimeout").Return(tt.args.httpTimeout, tt.args.httpTimeoutErr) + flagSetMock.On("FetchRootFlagInput", mock.Anything, mock.Anything).Return(tt.args.waitTime, tt.args.waitTimeErr) + flagSetMock.On("Changed", mock.Anything, mock.Anything).Return(tt.args.isFlagSet) + utils := &UtilsStruct{} - got, err := utils.GetHTTPTimeout() + got, err := utils.GetWaitTime() if got != tt.want { - t.Errorf("getHTTPTimeout() got = %v, want %v", got, tt.want) + t.Errorf("GetWaitTime() got = %v, want %v", got, tt.want) } if err == nil || tt.wantErr == nil { if err != tt.wantErr { - t.Errorf("Error for getHTTPTimeout function, got = %v, want = %v", err, tt.wantErr) + t.Errorf("Error for GetWaitTime function, got = %v, want = %v", err, tt.wantErr) } } else { if err.Error() != tt.wantErr.Error() { - t.Errorf("Error for getHTTPTimeout function, got = %v, want = %v", err, tt.wantErr) + t.Errorf("Error for GetWaitTime function, got = %v, want = %v", err, tt.wantErr) } } }) diff --git a/cmd/confirm.go b/cmd/confirm.go index 000cc3616..e9d308b04 100644 --- a/cmd/confirm.go +++ b/cmd/confirm.go @@ -29,7 +29,7 @@ func (*UtilsStruct) ClaimBlockReward(options types.TransactionOptions) (common.H return core.NilHash, nil } - stakerID, err := razorUtils.GetStakerId(options.Client, options.AccountAddress) + stakerID, err := razorUtils.GetStakerId(options.Client, options.Account.Address) if err != nil { log.Error("Error in getting stakerId: ", err) return core.NilHash, err diff --git a/cmd/confirm_test.go b/cmd/confirm_test.go index 1079c14aa..cd7977f1a 100644 --- a/cmd/confirm_test.go +++ b/cmd/confirm_test.go @@ -1,10 +1,7 @@ package cmd import ( - "crypto/ecdsa" - "crypto/rand" "errors" - "github.com/ethereum/go-ethereum/crypto" "github.com/stretchr/testify/mock" "math/big" "razor/core" @@ -12,7 +9,6 @@ import ( "razor/pkg/bindings" "testing" - "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" Types "github.com/ethereum/go-ethereum/core/types" ) @@ -20,9 +16,6 @@ import ( func TestClaimBlockReward(t *testing.T) { var options types.TransactionOptions - privateKey, _ := ecdsa.GenerateKey(crypto.S256(), rand.Reader) - txnOpts, _ := bind.NewKeyedTransactorWithChainID(privateKey, big.NewInt(1)) - type args struct { epoch uint32 epochErr error @@ -32,7 +25,6 @@ func TestClaimBlockReward(t *testing.T) { sortedProposedBlockIdsErr error selectedBlock bindings.StructsBlock selectedBlockErr error - txnOpts *bind.TransactOpts ClaimBlockRewardTxn *Types.Transaction ClaimBlockRewardErr error hash common.Hash @@ -50,7 +42,6 @@ func TestClaimBlockReward(t *testing.T) { stakerId: 2, sortedProposedBlockIds: []uint32{2, 1, 3}, selectedBlock: bindings.StructsBlock{ProposerId: 2}, - txnOpts: txnOpts, ClaimBlockRewardTxn: &Types.Transaction{}, ClaimBlockRewardErr: nil, hash: common.BigToHash(big.NewInt(1)), @@ -65,7 +56,6 @@ func TestClaimBlockReward(t *testing.T) { stakerId: 2, sortedProposedBlockIds: []uint32{2, 1, 3}, selectedBlock: bindings.StructsBlock{ProposerId: 2}, - txnOpts: txnOpts, ClaimBlockRewardTxn: &Types.Transaction{}, ClaimBlockRewardErr: errors.New("claimBlockReward error"), hash: common.BigToHash(big.NewInt(1)), @@ -119,7 +109,6 @@ func TestClaimBlockReward(t *testing.T) { stakerId: 3, sortedProposedBlockIds: []uint32{2, 1, 3}, selectedBlock: bindings.StructsBlock{ProposerId: 2}, - txnOpts: txnOpts, ClaimBlockRewardTxn: &Types.Transaction{}, ClaimBlockRewardErr: nil, hash: common.BigToHash(big.NewInt(1)), @@ -144,7 +133,7 @@ func TestClaimBlockReward(t *testing.T) { utilsMock.On("GetSortedProposedBlockIds", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("uint32")).Return(tt.args.sortedProposedBlockIds, tt.args.sortedProposedBlockIdsErr) utilsMock.On("GetStakerId", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("string")).Return(tt.args.stakerId, tt.args.stakerIdErr) utilsMock.On("GetProposedBlock", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("uint32"), mock.AnythingOfType("uint32")).Return(tt.args.selectedBlock, tt.args.selectedBlockErr) - utilsMock.On("GetTxnOpts", options).Return(tt.args.txnOpts) + utilsMock.On("GetTxnOpts", options).Return(TxnOpts) blockManagerMock.On("ClaimBlockReward", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("*bind.TransactOpts")).Return(tt.args.ClaimBlockRewardTxn, tt.args.ClaimBlockRewardErr) transactionMock.On("Hash", mock.AnythingOfType("*types.Transaction")).Return(tt.args.hash) diff --git a/cmd/create.go b/cmd/create.go index 639318672..71ca08a98 100644 --- a/cmd/create.go +++ b/cmd/create.go @@ -49,8 +49,14 @@ func (*UtilsStruct) Create(password string) (accounts.Account, error) { return accounts.Account{Address: common.Address{0x00}}, err } log.Debug("Create: .razor directory path: ", razorPath) + accountManager, err := razorUtils.AccountManagerForKeystore() + if err != nil { + log.Error("Error in getting accounts manager for keystore: ", err) + return accounts.Account{Address: common.Address{0x00}}, err + } + keystorePath := filepath.Join(razorPath, "keystore_files") - account := accountUtils.CreateAccount(keystorePath, password) + account := accountManager.CreateAccount(keystorePath, password) return account, nil } diff --git a/cmd/createCollection.go b/cmd/createCollection.go index be61d6509..d6c44cd93 100644 --- a/cmd/createCollection.go +++ b/cmd/createCollection.go @@ -2,6 +2,7 @@ package cmd import ( + "razor/accounts" "razor/core" "razor/core/types" "razor/logger" @@ -51,7 +52,12 @@ func (*UtilsStruct) ExecuteCreateCollection(flagSet *pflag.FlagSet) { log.Debug("Getting password...") password := razorUtils.AssignPassword(flagSet) - err = razorUtils.CheckPassword(address, password) + accountManager, err := razorUtils.AccountManagerForKeystore() + utils.CheckError("Error in getting accounts manager for keystore: ", err) + + account := accounts.InitAccountStruct(address, password, accountManager) + + err = razorUtils.CheckPassword(account) utils.CheckError("Error in fetching private key from given password: ", err) name, err := flagSetUtils.GetStringName(flagSet) @@ -70,16 +76,14 @@ func (*UtilsStruct) ExecuteCreateCollection(flagSet *pflag.FlagSet) { utils.CheckError("Error in getting tolerance: ", err) collectionInput := types.CreateCollectionInput{ - Address: address, - Password: password, Power: power, Name: name, Aggregation: aggregation, JobIds: jobIdInUint, Tolerance: tolerance, + Account: account, } - log.Debugf("Calling CreateCollection() with argument collectionInput: %+v", collectionInput) txn, err := cmdUtils.CreateCollection(client, config, collectionInput) utils.CheckError("CreateCollection error: ", err) err = razorUtils.WaitForBlockCompletion(client, txn.Hex()) @@ -97,14 +101,13 @@ func (*UtilsStruct) CreateCollection(client *ethclient.Client, config types.Conf } txnOpts := razorUtils.GetTxnOpts(types.TransactionOptions{ Client: client, - Password: collectionInput.Password, - AccountAddress: collectionInput.Address, ChainId: core.ChainId, Config: config, ContractAddress: core.CollectionManagerAddress, MethodName: "createCollection", Parameters: []interface{}{collectionInput.Tolerance, collectionInput.Power, collectionInput.Aggregation, jobIds, collectionInput.Name}, ABI: bindings.CollectionManagerMetaData.ABI, + Account: collectionInput.Account, }) log.Debugf("Executing CreateCollection transaction with tolerance: %d, power = %d , aggregation = %d, jobIds = %v, name = %s", collectionInput.Tolerance, collectionInput.Power, collectionInput.Aggregation, jobIds, collectionInput.Name) txn, err := assetManagerUtils.CreateCollection(client, txnOpts, collectionInput.Tolerance, collectionInput.Power, collectionInput.Aggregation, jobIds, collectionInput.Name) diff --git a/cmd/createCollection_test.go b/cmd/createCollection_test.go index 742ee3b11..cb8cfcc08 100644 --- a/cmd/createCollection_test.go +++ b/cmd/createCollection_test.go @@ -1,16 +1,13 @@ package cmd import ( - "crypto/ecdsa" - "crypto/rand" "errors" - "github.com/ethereum/go-ethereum/crypto" "math/big" + "razor/accounts" "razor/core" "razor/core/types" "testing" - "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" Types "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethclient" @@ -19,17 +16,12 @@ import ( ) func TestCreateCollection(t *testing.T) { - - privateKey, _ := ecdsa.GenerateKey(crypto.S256(), rand.Reader) - txnOpts, _ := bind.NewKeyedTransactorWithChainID(privateKey, big.NewInt(1)) - var client *ethclient.Client var WaitForDisputeOrConfirmStateStatus uint32 var config types.Configurations var collectionInput types.CreateCollectionInput type args struct { - txnOpts *bind.TransactOpts jobIdUint8 []uint16 waitForAppropriateStateErr error createCollectionTxn *Types.Transaction @@ -45,7 +37,6 @@ func TestCreateCollection(t *testing.T) { { name: "Test 1: When CreateCollection function executes successfully", args: args{ - txnOpts: txnOpts, jobIdUint8: []uint16{1, 2}, createCollectionTxn: &Types.Transaction{}, hash: common.BigToHash(big.NewInt(1)), @@ -56,7 +47,6 @@ func TestCreateCollection(t *testing.T) { { name: "Test 2: When there is an error in WaitForConfirmState", args: args{ - txnOpts: txnOpts, jobIdUint8: []uint16{1, 2}, waitForAppropriateStateErr: errors.New("waitForDisputeOrConfirmState error"), createCollectionTxn: &Types.Transaction{}, @@ -68,7 +58,6 @@ func TestCreateCollection(t *testing.T) { { name: "Test 3: When CreateCollection transaction fails", args: args{ - txnOpts: txnOpts, jobIdUint8: []uint16{1, 2}, createCollectionTxn: &Types.Transaction{}, createCollectionErr: errors.New("createCollection error"), @@ -83,7 +72,7 @@ func TestCreateCollection(t *testing.T) { SetUpMockInterfaces() utilsMock.On("ConvertUintArrayToUint16Array", mock.Anything).Return(tt.args.jobIdUint8) - utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(txnOpts) + utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(TxnOpts) cmdUtilsMock.On("WaitForAppropriateState", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("string"), mock.Anything).Return(WaitForDisputeOrConfirmStateStatus, tt.args.waitForAppropriateStateErr) assetManagerMock.On("CreateCollection", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.createCollectionTxn, tt.args.createCollectionErr) transactionMock.On("Hash", mock.Anything).Return(tt.args.hash) @@ -287,7 +276,8 @@ func TestExecuteCreateCollection(t *testing.T) { fileUtilsMock.On("AssignLogFile", mock.AnythingOfType("*pflag.FlagSet"), mock.Anything) cmdUtilsMock.On("GetConfigData").Return(tt.args.config, tt.args.configErr) utilsMock.On("AssignPassword", flagSet).Return(tt.args.password) - utilsMock.On("CheckPassword", mock.Anything, mock.Anything).Return(nil) + utilsMock.On("CheckPassword", mock.Anything).Return(nil) + utilsMock.On("AccountManagerForKeystore").Return(&accounts.AccountManager{}, nil) flagSetMock.On("GetStringAddress", flagSet).Return(tt.args.address, tt.args.addressErr) flagSetMock.On("GetStringName", flagSet).Return(tt.args.name, tt.args.nameErr) flagSetMock.On("GetUintSliceJobIds", flagSet).Return(tt.args.jobId, tt.args.jobIdErr) diff --git a/cmd/createJob.go b/cmd/createJob.go index 30a399fda..df93330cf 100644 --- a/cmd/createJob.go +++ b/cmd/createJob.go @@ -2,6 +2,7 @@ package cmd import ( + "razor/accounts" "razor/core" "razor/core/types" "razor/logger" @@ -52,7 +53,12 @@ func (*UtilsStruct) ExecuteCreateJob(flagSet *pflag.FlagSet) { log.Debug("Getting password...") password := razorUtils.AssignPassword(flagSet) - err = razorUtils.CheckPassword(address, password) + accountManager, err := razorUtils.AccountManagerForKeystore() + utils.CheckError("Error in getting accounts manager for keystore: ", err) + + account := accounts.InitAccountStruct(address, password, accountManager) + + err = razorUtils.CheckPassword(account) utils.CheckError("Error in fetching private key from given password: ", err) name, err := flagSetUtils.GetStringName(flagSet) @@ -74,16 +80,15 @@ func (*UtilsStruct) ExecuteCreateJob(flagSet *pflag.FlagSet) { utils.CheckError("Error in getting selectorType: ", err) jobInput := types.CreateJobInput{ - Address: address, - Password: password, Url: url, Name: name, Selector: selector, SelectorType: selectorType, Weight: weight, Power: power, + Account: account, } - log.Debugf("ExecuteCreateJob: Calling CreateJob() with argument jobInput: %+v", jobInput) + txn, err := cmdUtils.CreateJob(client, config, jobInput) utils.CheckError("CreateJob error: ", err) err = razorUtils.WaitForBlockCompletion(client, txn.Hex()) @@ -94,14 +99,13 @@ func (*UtilsStruct) ExecuteCreateJob(flagSet *pflag.FlagSet) { func (*UtilsStruct) CreateJob(client *ethclient.Client, config types.Configurations, jobInput types.CreateJobInput) (common.Hash, error) { txnArgs := types.TransactionOptions{ Client: client, - Password: jobInput.Password, - AccountAddress: jobInput.Address, ChainId: core.ChainId, Config: config, ContractAddress: core.CollectionManagerAddress, MethodName: "createJob", Parameters: []interface{}{jobInput.Weight, jobInput.Power, jobInput.SelectorType, jobInput.Name, jobInput.Selector, jobInput.Url}, ABI: bindings.CollectionManagerMetaData.ABI, + Account: jobInput.Account, } txnOpts := razorUtils.GetTxnOpts(txnArgs) diff --git a/cmd/createJob_test.go b/cmd/createJob_test.go index 9286c3db6..81bb56119 100644 --- a/cmd/createJob_test.go +++ b/cmd/createJob_test.go @@ -1,16 +1,13 @@ package cmd import ( - "crypto/ecdsa" - "crypto/rand" "errors" - "github.com/ethereum/go-ethereum/crypto" "math/big" + "razor/accounts" "razor/core" "razor/core/types" "testing" - "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" Types "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethclient" @@ -23,11 +20,7 @@ func TestCreateJob(t *testing.T) { var jobInput types.CreateJobInput var config types.Configurations - privateKey, _ := ecdsa.GenerateKey(crypto.S256(), rand.Reader) - txnOpts, _ := bind.NewKeyedTransactorWithChainID(privateKey, big.NewInt(1)) - type args struct { - txnOpts *bind.TransactOpts createJobTxn *Types.Transaction createJobErr error hash common.Hash @@ -41,7 +34,6 @@ func TestCreateJob(t *testing.T) { { name: "Test 1: When createJob function executes successfully", args: args{ - txnOpts: txnOpts, createJobTxn: &Types.Transaction{}, hash: common.BigToHash(big.NewInt(1)), }, @@ -51,7 +43,6 @@ func TestCreateJob(t *testing.T) { { name: "Test 2: When createJob transaction fails", args: args{ - txnOpts: txnOpts, createJobTxn: &Types.Transaction{}, createJobErr: errors.New("createJob error"), hash: common.BigToHash(big.NewInt(1)), @@ -64,7 +55,7 @@ func TestCreateJob(t *testing.T) { t.Run(tt.name, func(t *testing.T) { SetUpMockInterfaces() - utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(txnOpts) + utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(TxnOpts) assetManagerMock.On("CreateJob", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("*bind.TransactOpts"), mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.createJobTxn, tt.args.createJobErr) transactionMock.On("Hash", mock.Anything).Return(tt.args.hash) @@ -307,7 +298,8 @@ func TestExecuteCreateJob(t *testing.T) { fileUtilsMock.On("AssignLogFile", mock.AnythingOfType("*pflag.FlagSet"), mock.Anything) cmdUtilsMock.On("GetConfigData").Return(tt.args.config, tt.args.configErr) utilsMock.On("AssignPassword", flagSet).Return(tt.args.password) - utilsMock.On("CheckPassword", mock.Anything, mock.Anything).Return(nil) + utilsMock.On("CheckPassword", mock.Anything).Return(nil) + utilsMock.On("AccountManagerForKeystore").Return(&accounts.AccountManager{}, nil) flagSetMock.On("GetStringAddress", flagSet).Return(tt.args.address, tt.args.addressErr) flagSetMock.On("GetStringName", flagSet).Return(tt.args.name, tt.args.nameErr) flagSetMock.On("GetStringUrl", flagSet).Return(tt.args.url, tt.args.urlErr) diff --git a/cmd/create_test.go b/cmd/create_test.go index 2ddd5a02d..cb7bf36e9 100644 --- a/cmd/create_test.go +++ b/cmd/create_test.go @@ -2,7 +2,10 @@ package cmd import ( "errors" + accountsPkgMocks "razor/accounts/mocks" "razor/core/types" + pathPkgMocks "razor/path/mocks" + utilsPkgMocks "razor/utils/mocks" "testing" "github.com/ethereum/go-ethereum/accounts" @@ -14,10 +17,15 @@ import ( func TestCreate(t *testing.T) { var password string + nilAccount := accounts.Account{Address: common.Address{0x00}, + URL: accounts.URL{Scheme: "TestKeyScheme", Path: "test/key/path"}, + } + type args struct { - path string - pathErr error - account accounts.Account + path string + pathErr error + accountManagerErr error + account accounts.Account } tests := []struct { name string @@ -48,21 +56,40 @@ func TestCreate(t *testing.T) { URL: accounts.URL{Scheme: "TestKeyScheme", Path: "test/key/path"}, }, }, - want: accounts.Account{Address: common.Address{0x00}, - URL: accounts.URL{Scheme: "TestKeyScheme", Path: "test/key/path"}, - }, + want: nilAccount, wantErr: errors.New("path error"), }, + { + name: "Test 3: When there is an error in getting account manager", + args: args{ + path: "/home/local", + pathErr: nil, + accountManagerErr: errors.New("account manager error"), + account: accounts.Account{Address: common.HexToAddress("0x000000000000000000000000000000000000dea1"), + URL: accounts.URL{Scheme: "TestKeyScheme", Path: "test/key/path"}, + }, + }, + want: nilAccount, + wantErr: errors.New("account manager error"), + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - SetUpMockInterfaces() + accountsMock := new(accountsPkgMocks.AccountManagerInterface) + + var pathMock *pathPkgMocks.PathInterface + var utilsMock *utilsPkgMocks.Utils + + pathMock = new(pathPkgMocks.PathInterface) + pathUtils = pathMock + + utilsMock = new(utilsPkgMocks.Utils) + razorUtils = utilsMock pathMock.On("GetDefaultPath").Return(tt.args.path, tt.args.pathErr) - accountsMock.On("CreateAccount", mock.AnythingOfType("string"), mock.AnythingOfType("string")).Return(accounts.Account{ - Address: tt.args.account.Address, - URL: accounts.URL{Scheme: "TestKeyScheme", Path: "test/key/path"}, - }) + utilsMock.On("AccountManagerForKeystore").Return(accountsMock, tt.args.accountManagerErr) + + accountsMock.On("CreateAccount", mock.Anything, mock.Anything).Return(tt.args.account) utils := &UtilsStruct{} got, err := utils.Create(password) @@ -131,7 +158,7 @@ func TestExecuteCreate(t *testing.T) { fileUtilsMock.On("AssignLogFile", mock.AnythingOfType("*pflag.FlagSet"), mock.Anything) utilsMock.On("AssignPassword", mock.AnythingOfType("*pflag.FlagSet")).Return(tt.args.password) - utilsMock.On("CheckPassword", mock.Anything, mock.Anything).Return(nil) + utilsMock.On("CheckPassword", mock.Anything).Return(nil) cmdUtilsMock.On("Create", mock.AnythingOfType("string")).Return(tt.args.account, tt.args.accountErr) cmdUtilsMock.On("GetConfigData").Return(types.Configurations{}, nil) diff --git a/cmd/delegate.go b/cmd/delegate.go index 9645906ff..f14ac0b6d 100644 --- a/cmd/delegate.go +++ b/cmd/delegate.go @@ -2,6 +2,7 @@ package cmd import ( + "razor/accounts" "razor/core" "razor/core/types" "razor/logger" @@ -49,7 +50,12 @@ func (*UtilsStruct) ExecuteDelegate(flagSet *pflag.FlagSet) { log.Debug("Getting password...") password := razorUtils.AssignPassword(flagSet) - err = razorUtils.CheckPassword(address, password) + accountManager, err := razorUtils.AccountManagerForKeystore() + utils.CheckError("Error in getting accounts manager for keystore: ", err) + + account := accounts.InitAccountStruct(address, password, accountManager) + + err = razorUtils.CheckPassword(account) utils.CheckError("Error in fetching private key from given password: ", err) stakerId, err := flagSetUtils.GetUint32StakerId(flagSet) @@ -70,15 +76,13 @@ func (*UtilsStruct) ExecuteDelegate(flagSet *pflag.FlagSet) { razorUtils.CheckEthBalanceIsZero(client, address) txnArgs := types.TransactionOptions{ - Client: client, - Password: password, - Amount: valueInWei, - AccountAddress: address, - ChainId: core.ChainId, - Config: config, + Client: client, + Amount: valueInWei, + ChainId: core.ChainId, + Config: config, + Account: account, } - log.Debugf("ExecuteDelegate: Calling Approve() with transaction arguments: %+v", txnArgs) approveTxnHash, err := cmdUtils.Approve(txnArgs) utils.CheckError("Approve error: ", err) diff --git a/cmd/delegate_test.go b/cmd/delegate_test.go index 20a743a2c..fe6ac8811 100644 --- a/cmd/delegate_test.go +++ b/cmd/delegate_test.go @@ -1,16 +1,13 @@ package cmd import ( - "crypto/ecdsa" - "crypto/rand" "errors" - "github.com/ethereum/go-ethereum/crypto" "math/big" + "razor/accounts" "razor/core" "razor/core/types" "testing" - "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" Types "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethclient" @@ -19,14 +16,10 @@ import ( ) func TestDelegate(t *testing.T) { - privateKey, _ := ecdsa.GenerateKey(crypto.S256(), rand.Reader) - txnOpts, _ := bind.NewKeyedTransactorWithChainID(privateKey, big.NewInt(1)) - var stakerId uint32 = 1 type args struct { amount *big.Int - txnOpts *bind.TransactOpts delegateTxn *Types.Transaction delegateErr error hash common.Hash @@ -41,7 +34,6 @@ func TestDelegate(t *testing.T) { name: "Test 1: When delegate function executes successfully", args: args{ amount: big.NewInt(1000), - txnOpts: txnOpts, delegateTxn: &Types.Transaction{}, delegateErr: nil, hash: common.BigToHash(big.NewInt(1)), @@ -53,7 +45,6 @@ func TestDelegate(t *testing.T) { name: "Test 2: When delegate transaction fails", args: args{ amount: big.NewInt(1000), - txnOpts: txnOpts, delegateTxn: &Types.Transaction{}, delegateErr: errors.New("delegate error"), hash: common.BigToHash(big.NewInt(1)), @@ -66,7 +57,7 @@ func TestDelegate(t *testing.T) { t.Run(tt.name, func(t *testing.T) { SetUpMockInterfaces() - utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(txnOpts) + utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(TxnOpts) stakeManagerMock.On("Delegate", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.delegateTxn, tt.args.delegateErr) transactionMock.On("Hash", mock.Anything).Return(tt.args.hash) @@ -249,7 +240,8 @@ func TestExecuteDelegate(t *testing.T) { fileUtilsMock.On("AssignLogFile", mock.AnythingOfType("*pflag.FlagSet"), mock.Anything) cmdUtilsMock.On("GetConfigData").Return(tt.args.config, tt.args.configErr) utilsMock.On("AssignPassword", flagSet).Return(tt.args.password) - utilsMock.On("CheckPassword", mock.Anything, mock.Anything).Return(nil) + utilsMock.On("CheckPassword", mock.Anything).Return(nil) + utilsMock.On("AccountManagerForKeystore").Return(&accounts.AccountManager{}, nil) flagSetMock.On("GetStringAddress", mock.AnythingOfType("*pflag.FlagSet")).Return(tt.args.address, tt.args.addressErr) flagSetMock.On("GetUint32StakerId", flagSet).Return(tt.args.stakerId, tt.args.stakerIdErr) utilsMock.On("ConnectToClient", mock.AnythingOfType("string")).Return(client) diff --git a/cmd/dispute.go b/cmd/dispute.go index ae662c917..516cef698 100644 --- a/cmd/dispute.go +++ b/cmd/dispute.go @@ -36,7 +36,7 @@ func (*UtilsStruct) HandleDispute(client *ethclient.Client, config types.Configu } log.Debug("HandleDispute: SortedProposedBlockIds: ", sortedProposedBlockIds) - biggestStake, biggestStakerId, err := cmdUtils.GetBiggestStakeAndId(client, account.Address, epoch) + biggestStake, biggestStakerId, err := cmdUtils.GetBiggestStakeAndId(client, epoch) if err != nil { return err } @@ -57,11 +57,10 @@ func (*UtilsStruct) HandleDispute(client *ethclient.Client, config types.Configu randomSortedProposedBlockIds := utils.Shuffle(sortedProposedBlockIds) //shuffles the sortedProposedBlockIds array transactionOptions := types.TransactionOptions{ - Client: client, - Password: account.Password, - AccountAddress: account.Address, - ChainId: core.ChainId, - Config: config, + Client: client, + ChainId: core.ChainId, + Config: config, + Account: account, } log.Debug("HandleDispute: Shuffled sorted proposed blocks: ", randomSortedProposedBlockIds) @@ -88,11 +87,10 @@ func (*UtilsStruct) HandleDispute(client *ethclient.Client, config types.Configu log.Warn("PROPOSED BIGGEST STAKE DOES NOT MATCH WITH ACTUAL BIGGEST STAKE") log.Info("Disputing BiggestStakeProposed...") txnOpts := razorUtils.GetTxnOpts(types.TransactionOptions{ - Client: client, - Password: account.Password, - AccountAddress: account.Address, - ChainId: core.ChainId, - Config: config, + Client: client, + ChainId: core.ChainId, + Config: config, + Account: account, }) log.Debugf("Executing DisputeBiggestStakeProposed transaction with arguments epoch = %d, blockIndex = %d, biggest staker Id = %d", epoch, blockIndex, biggestStakerId) disputeBiggestStakeProposedTxn, err := blockManagerUtils.DisputeBiggestStakeProposed(client, txnOpts, epoch, uint8(blockIndex), biggestStakerId) @@ -324,11 +322,10 @@ func (*UtilsStruct) Dispute(client *ethclient.Client, config types.Configuration blockManager := razorUtils.GetBlockManager(client) txnArgs := types.TransactionOptions{ - Client: client, - Password: account.Password, - AccountAddress: account.Address, - ChainId: core.ChainId, - Config: config, + Client: client, + ChainId: core.ChainId, + Config: config, + Account: account, } if !utils.Contains(giveSortedLeafIds, leafId) { @@ -429,7 +426,7 @@ func GiveSorted(client *ethclient.Client, blockManager *bindings.BlockManager, t } callOpts := razorUtils.GetOptions() txnOpts := razorUtils.GetTxnOpts(txnArgs) - disputesMapping, err := blockManagerUtils.Disputes(client, &callOpts, epoch, common.HexToAddress(txnArgs.AccountAddress)) + disputesMapping, err := blockManagerUtils.Disputes(client, &callOpts, epoch, common.HexToAddress(txnArgs.Account.Address)) if err != nil { log.Error("Error in getting disputes mapping: ", disputesMapping) return err diff --git a/cmd/dispute_test.go b/cmd/dispute_test.go index 141fbde93..2a473f238 100644 --- a/cmd/dispute_test.go +++ b/cmd/dispute_test.go @@ -22,9 +22,6 @@ import ( ) func TestDispute(t *testing.T) { - privateKey, _ := ecdsa.GenerateKey(crypto.S256(), rand.Reader) - txnOpts, _ := bind.NewKeyedTransactorWithChainID(privateKey, big.NewInt(31337)) - var ( client *ethclient.Client config types.Configurations @@ -96,7 +93,7 @@ func TestDispute(t *testing.T) { SetUpMockInterfaces() utilsMock.On("GetBlockManager", mock.AnythingOfType("*ethclient.Client")).Return(blockManager) - utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(txnOpts) + utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(TxnOpts) cmdUtilsMock.On("GiveSorted", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) cmdUtilsMock.On("GetCollectionIdPositionInBlock", mock.Anything, mock.Anything, mock.Anything).Return(tt.args.positionOfCollectionInBlock) blockManagerMock.On("FinalizeDispute", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.finalizeDisputeTxn, tt.args.finalizeDisputeErr) @@ -528,7 +525,7 @@ func TestHandleDispute(t *testing.T) { SetUpMockInterfaces() utilsMock.On("GetSortedProposedBlockIds", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("uint32")).Return(tt.args.sortedProposedBlockIds, tt.args.sortedProposedBlockIdsErr) - cmdUtilsMock.On("GetBiggestStakeAndId", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("string"), mock.AnythingOfType("uint32")).Return(tt.args.biggestStake, tt.args.biggestStakeId, tt.args.biggestStakeErr) + cmdUtilsMock.On("GetBiggestStakeAndId", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("uint32")).Return(tt.args.biggestStake, tt.args.biggestStakeId, tt.args.biggestStakeErr) cmdUtilsMock.On("GetLocalMediansData", mock.AnythingOfType("*ethclient.Client"), mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(types.ProposeFileData{ MediansData: tt.args.medians, RevealedCollectionIds: tt.args.revealedCollectionIds, @@ -1206,7 +1203,7 @@ func BenchmarkHandleDispute(b *testing.B) { BiggestStake: big.NewInt(1).Mul(big.NewInt(5356), big.NewInt(1e18))} utilsMock.On("GetSortedProposedBlockIds", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("uint32")).Return(getUint32DummyIds(v.numOfSortedBlocks), nil) - cmdUtilsMock.On("GetBiggestStakeAndId", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("string"), mock.AnythingOfType("uint32")).Return(big.NewInt(1).Mul(big.NewInt(5356), big.NewInt(1e18)), uint32(2), nil) + cmdUtilsMock.On("GetBiggestStakeAndId", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("uint32")).Return(big.NewInt(1).Mul(big.NewInt(5356), big.NewInt(1e18)), uint32(2), nil) cmdUtilsMock.On("GetLocalMediansData", mock.AnythingOfType("*ethclient.Client"), mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(proposedData, nil) utilsMock.On("GetProposedBlock", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("uint32"), mock.AnythingOfType("uint32")).Return(proposedBlock, nil) utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(txnOpts) diff --git a/cmd/eventListeners.go b/cmd/eventListeners.go new file mode 100644 index 000000000..7b8d876fa --- /dev/null +++ b/cmd/eventListeners.go @@ -0,0 +1,135 @@ +package cmd + +import ( + "github.com/ethereum/go-ethereum" + "github.com/ethereum/go-ethereum/accounts/abi" + "github.com/ethereum/go-ethereum/common" + Types "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/ethclient" + "math/big" + "razor/cache" + "razor/core" + "razor/core/types" + "razor/pkg/bindings" + "razor/utils" + "strings" +) + +func (*UtilsStruct) InitJobAndCollectionCache(client *ethclient.Client) (*cache.JobsCache, *cache.CollectionsCache, *big.Int, error) { + initAssetCacheBlock, err := clientUtils.GetLatestBlockWithRetry(client) + if err != nil { + log.Error("Error in fetching block: ", err) + return nil, nil, nil, err + } + log.Debugf("InitJobAndCollectionCache: Latest header value when initializing jobs and collections cache: %d", initAssetCacheBlock.Number) + + log.Info("INITIALIZING JOBS AND COLLECTIONS CACHE...") + + // Create instances of cache + jobsCache := cache.NewJobsCache() + collectionsCache := cache.NewCollectionsCache() + + // Initialize caches + if err := utils.InitJobsCache(client, jobsCache); err != nil { + log.Error("Error in initializing jobs cache: ", err) + return nil, nil, nil, err + } + if err := utils.InitCollectionsCache(client, collectionsCache); err != nil { + log.Error("Error in initializing collections cache: ", err) + return nil, nil, nil, err + } + + return jobsCache, collectionsCache, initAssetCacheBlock.Number, nil +} + +// CheckForJobAndCollectionEvents checks for specific job and collections event that were emitted. +func CheckForJobAndCollectionEvents(client *ethclient.Client, commitParams *types.CommitParams) error { + collectionManagerContractABI, err := abi.JSON(strings.NewReader(bindings.CollectionManagerMetaData.ABI)) + if err != nil { + log.Errorf("Error in parsing collection manager contract ABI: %v", err) + return err + } + + eventNames := []string{core.JobUpdatedEvent, core.CollectionUpdatedEvent, core.CollectionActivityStatusEvent, core.JobCreatedEvent, core.CollectionCreatedEvent} + + log.Debug("Checking for Job/Collection update events...") + toBlock, err := clientUtils.GetLatestBlockWithRetry(client) + if err != nil { + log.Error("Error in getting latest block to start event listener: ", err) + return err + } + + // Process events and update the fromBlock for the next iteration + newFromBlock, err := processEvents(client, collectionManagerContractABI, commitParams.FromBlockToCheckForEvents, toBlock.Number, eventNames, commitParams.JobsCache, commitParams.CollectionsCache) + if err != nil { + return err + } + + // Update the commitParams with the new fromBlock + commitParams.FromBlockToCheckForEvents = new(big.Int).Add(newFromBlock, big.NewInt(1)) + + return nil +} + +// processEvents fetches and processes logs for multiple event types. +func processEvents(client *ethclient.Client, contractABI abi.ABI, fromBlock, toBlock *big.Int, eventNames []string, jobsCache *cache.JobsCache, collectionsCache *cache.CollectionsCache) (*big.Int, error) { + logs, err := getEventLogs(client, fromBlock, toBlock) + if err != nil { + log.Errorf("Failed to fetch logs: %v", err) + return nil, err + } + + for _, eventName := range eventNames { + eventID := contractABI.Events[eventName].ID.Hex() + for _, vLog := range logs { + if len(vLog.Topics) > 0 && vLog.Topics[0].Hex() == eventID { + switch eventName { + case core.JobUpdatedEvent, core.JobCreatedEvent: + jobId := utils.ConvertHashToUint16(vLog.Topics[1]) + updatedJob, err := utils.UtilsInterface.GetActiveJob(client, jobId) + if err != nil { + log.Errorf("Error in getting job with job Id %v: %v", jobId, err) + continue + } + log.Debugf("RECEIVED JOB EVENT: Updating the job with Id %v with details %+v...", jobId, updatedJob) + jobsCache.UpdateJob(jobId, updatedJob) + case core.CollectionUpdatedEvent, core.CollectionCreatedEvent, core.CollectionActivityStatusEvent: + collectionId := utils.ConvertHashToUint16(vLog.Topics[1]) + newCollection, err := utils.UtilsInterface.GetCollection(client, collectionId) + if err != nil { + log.Errorf("Error in getting collection with collection Id %v: %v", collectionId, err) + continue + } + log.Debugf("RECEIVED COLLECTION EVENT: Updating the collection with ID %v with details %+v", collectionId, newCollection) + collectionsCache.UpdateCollection(collectionId, newCollection) + } + } + } + } + + // Return the new toBlock for the next iteration + return toBlock, nil +} + +// getEventLogs is a utility function to fetch the event logs +func getEventLogs(client *ethclient.Client, fromBlock *big.Int, toBlock *big.Int) ([]Types.Log, error) { + log.Debugf("Checking for events from block %v to block %v...", fromBlock, toBlock) + + // Set up the query for filtering logs + query := ethereum.FilterQuery{ + FromBlock: fromBlock, + ToBlock: toBlock, + Addresses: []common.Address{ + common.HexToAddress(core.CollectionManagerAddress), + }, + } + + // Retrieve the logs + logs, err := clientUtils.FilterLogsWithRetry(client, query) + if err != nil { + log.Errorf("Error in filter logs: %v", err) + return []Types.Log{}, err + } + + return logs, nil +} diff --git a/cmd/test_utils.go b/cmd/initTestMocks_test.go similarity index 93% rename from cmd/test_utils.go rename to cmd/initTestMocks_test.go index afe5f44b2..f410e24a3 100644 --- a/cmd/test_utils.go +++ b/cmd/initTestMocks_test.go @@ -1,7 +1,11 @@ package cmd import ( - accountsPkgMocks "razor/accounts/mocks" + "crypto/ecdsa" + "crypto/rand" + "github.com/ethereum/go-ethereum/accounts/abi/bind" + "github.com/ethereum/go-ethereum/crypto" + "math/big" "razor/cmd/mocks" "razor/path" pathPkgMocks "razor/path/mocks" @@ -19,7 +23,6 @@ var ( ioUtilsMock *utilsPkgMocks.IOUtils abiUtilsMock *utilsPkgMocks.ABIUtils bindUtilsMock *utilsPkgMocks.BindUtils - accountUtilsMock *utilsPkgMocks.AccountsUtils blockManagerUtilsMock *utilsPkgMocks.BlockManagerUtils stakeManagerUtilsMock *utilsPkgMocks.StakeManagerUtils assetManagerUtilsMock *utilsPkgMocks.AssetManagerUtils @@ -47,7 +50,6 @@ var ( osMock *mocks.OSInterface pathMock *pathPkgMocks.PathInterface osPathMock *pathPkgMocks.OSInterface - accountsMock *accountsPkgMocks.AccountInterface ) func SetUpMockInterfaces() { @@ -84,9 +86,6 @@ func SetUpMockInterfaces() { bindUtilsMock = new(utilsPkgMocks.BindUtils) utils.BindingsInterface = bindingUtilsMock - accountUtilsMock = new(utilsPkgMocks.AccountsUtils) - utils.AccountsInterface = accountUtilsMock - blockManagerUtilsMock = new(utilsPkgMocks.BlockManagerUtils) utils.BlockManagerInterface = blockManagerUtilsMock @@ -164,7 +163,7 @@ func SetUpMockInterfaces() { osPathMock = new(pathPkgMocks.OSInterface) path.OSUtilsInterface = osPathMock - - accountsMock = new(accountsPkgMocks.AccountInterface) - accountUtils = accountsMock } + +var privateKey, _ = ecdsa.GenerateKey(crypto.S256(), rand.Reader) +var TxnOpts, _ = bind.NewKeyedTransactorWithChainID(privateKey, big.NewInt(31000)) // Used any random big int for chain ID diff --git a/cmd/initiateWithdraw.go b/cmd/initiateWithdraw.go index 8568baedf..f1393c452 100644 --- a/cmd/initiateWithdraw.go +++ b/cmd/initiateWithdraw.go @@ -4,6 +4,7 @@ package cmd import ( "errors" "math/big" + "razor/accounts" "razor/core" "razor/core/types" "razor/logger" @@ -50,7 +51,12 @@ func (*UtilsStruct) ExecuteInitiateWithdraw(flagSet *pflag.FlagSet) { log.Debug("Getting password...") password := razorUtils.AssignPassword(flagSet) - err = razorUtils.CheckPassword(address, password) + accountManager, err := razorUtils.AccountManagerForKeystore() + utils.CheckError("Error in getting accounts manager for keystore: ", err) + + account := accounts.InitAccountStruct(address, password, accountManager) + + err = razorUtils.CheckPassword(account) utils.CheckError("Error in fetching private key from given password: ", err) stakerId, err := razorUtils.AssignStakerId(flagSet, client, address) @@ -58,12 +64,9 @@ func (*UtilsStruct) ExecuteInitiateWithdraw(flagSet *pflag.FlagSet) { log.Debug("ExecuteInitiateWithdraw: Staker Id: ", stakerId) log.Debugf("ExecuteInitiateWithdraw: Calling HandleUnstakeLock() with arguments account address: %s, stakerId: %d", address, stakerId) - txn, err := cmdUtils.HandleUnstakeLock(client, types.Account{ - Address: address, - Password: password, - }, config, stakerId) - + txn, err := cmdUtils.HandleUnstakeLock(client, account, config, stakerId) utils.CheckError("InitiateWithdraw error: ", err) + if txn != core.NilHash { err := razorUtils.WaitForBlockCompletion(client, txn.Hex()) utils.CheckError("Error in WaitForBlockCompletion for initiateWithdraw: ", err) @@ -121,14 +124,13 @@ func (*UtilsStruct) HandleUnstakeLock(client *ethclient.Client, account types.Ac txnArgs := types.TransactionOptions{ Client: client, - Password: account.Password, - AccountAddress: account.Address, ChainId: core.ChainId, Config: configurations, ContractAddress: core.StakeManagerAddress, MethodName: "initiateWithdraw", ABI: bindings.StakeManagerMetaData.ABI, Parameters: []interface{}{stakerId}, + Account: account, } txnOpts := razorUtils.GetTxnOpts(txnArgs) diff --git a/cmd/initiateWithdraw_test.go b/cmd/initiateWithdraw_test.go index f4fe02d31..ac827497f 100644 --- a/cmd/initiateWithdraw_test.go +++ b/cmd/initiateWithdraw_test.go @@ -6,6 +6,7 @@ import ( "errors" "github.com/ethereum/go-ethereum/crypto" "math/big" + "razor/accounts" "razor/core" "razor/core/types" "testing" @@ -19,9 +20,6 @@ import ( ) func TestHandleUnstakeLock(t *testing.T) { - privateKey, _ := ecdsa.GenerateKey(crypto.S256(), rand.Reader) - txnOpts, _ := bind.NewKeyedTransactorWithChainID(privateKey, big.NewInt(1)) - var client *ethclient.Client var account types.Account var configurations types.Configurations @@ -34,7 +32,6 @@ func TestHandleUnstakeLock(t *testing.T) { lockErr error withdrawReleasePeriod uint16 withdrawReleasePeriodErr error - txnOpts *bind.TransactOpts epoch uint32 epochErr error time string @@ -54,7 +51,6 @@ func TestHandleUnstakeLock(t *testing.T) { UnlockAfter: big.NewInt(4), }, withdrawReleasePeriod: 4, - txnOpts: txnOpts, epoch: 5, withdrawHash: common.BigToHash(big.NewInt(1)), }, @@ -68,7 +64,6 @@ func TestHandleUnstakeLock(t *testing.T) { UnlockAfter: big.NewInt(4), }, withdrawReleasePeriod: 4, - txnOpts: txnOpts, epoch: 5, stateErr: errors.New("error in getting state"), }, @@ -80,7 +75,6 @@ func TestHandleUnstakeLock(t *testing.T) { args: args{ lockErr: errors.New("lock error"), withdrawReleasePeriod: 4, - txnOpts: txnOpts, epoch: 5, withdrawHash: common.BigToHash(big.NewInt(1)), }, @@ -94,7 +88,6 @@ func TestHandleUnstakeLock(t *testing.T) { UnlockAfter: big.NewInt(0), }, withdrawReleasePeriod: 4, - txnOpts: txnOpts, epoch: 5, withdrawHash: common.BigToHash(big.NewInt(1)), }, @@ -108,7 +101,6 @@ func TestHandleUnstakeLock(t *testing.T) { UnlockAfter: big.NewInt(4), }, withdrawReleasePeriodErr: errors.New("withdrawReleasePeriod error"), - txnOpts: txnOpts, epoch: 5, withdrawHash: common.BigToHash(big.NewInt(1)), }, @@ -122,7 +114,6 @@ func TestHandleUnstakeLock(t *testing.T) { UnlockAfter: big.NewInt(4), }, withdrawReleasePeriod: 4, - txnOpts: txnOpts, epochErr: errors.New("epoch error"), withdrawHash: common.BigToHash(big.NewInt(1)), }, @@ -136,7 +127,6 @@ func TestHandleUnstakeLock(t *testing.T) { UnlockAfter: big.NewInt(4), }, withdrawReleasePeriod: 4, - txnOpts: txnOpts, epoch: 9, withdrawHash: common.BigToHash(big.NewInt(1)), }, @@ -150,7 +140,6 @@ func TestHandleUnstakeLock(t *testing.T) { UnlockAfter: big.NewInt(4), }, withdrawReleasePeriod: 4, - txnOpts: txnOpts, epoch: 3, time: "10 minutes 0 seconds ", withdrawHash: common.BigToHash(big.NewInt(1)), @@ -165,7 +154,6 @@ func TestHandleUnstakeLock(t *testing.T) { UnlockAfter: big.NewInt(4), }, withdrawReleasePeriod: 1, - txnOpts: txnOpts, epoch: 5, withdrawErr: errors.New("withdraw error"), }, @@ -179,7 +167,6 @@ func TestHandleUnstakeLock(t *testing.T) { UnlockAfter: big.NewInt(4), }, withdrawReleasePeriod: 4, - txnOpts: txnOpts, epoch: 2, time: "20 minutes 0 seconds ", withdrawHash: common.BigToHash(big.NewInt(1)), @@ -196,7 +183,7 @@ func TestHandleUnstakeLock(t *testing.T) { utilsMock.On("GetLock", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("string"), mock.AnythingOfType("uint32"), mock.Anything).Return(tt.args.lock, tt.args.lockErr) utilsMock.On("GetWithdrawInitiationPeriod", mock.AnythingOfType("*ethclient.Client")).Return(tt.args.withdrawReleasePeriod, tt.args.withdrawReleasePeriodErr) utilsMock.On("GetEpoch", mock.AnythingOfType("*ethclient.Client")).Return(tt.args.epoch, tt.args.epochErr) - utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(txnOpts) + utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(TxnOpts) cmdUtilsMock.On("InitiateWithdraw", mock.Anything, mock.Anything, mock.Anything).Return(tt.args.withdrawHash, tt.args.withdrawErr) utilsMock.On("SecondsToReadableTime", mock.AnythingOfType("int")).Return(tt.args.time) @@ -377,7 +364,8 @@ func TestExecuteWithdraw(t *testing.T) { fileUtilsMock.On("AssignLogFile", mock.AnythingOfType("*pflag.FlagSet"), mock.Anything) cmdUtilsMock.On("GetConfigData").Return(tt.args.config, tt.args.configErr) utilsMock.On("AssignPassword", flagSet).Return(tt.args.password) - utilsMock.On("CheckPassword", mock.Anything, mock.Anything).Return(nil) + utilsMock.On("CheckPassword", mock.Anything).Return(nil) + utilsMock.On("AccountManagerForKeystore").Return(&accounts.AccountManager{}, nil) flagSetMock.On("GetStringAddress", flagSet).Return(tt.args.address, tt.args.addressErr) utilsMock.On("AssignStakerId", flagSet, mock.AnythingOfType("*ethclient.Client"), mock.Anything).Return(tt.args.stakerId, tt.args.stakerIdErr) utilsMock.On("ConnectToClient", mock.AnythingOfType("string")).Return(client) diff --git a/cmd/interface.go b/cmd/interface.go index a168b638e..51a1e127d 100644 --- a/cmd/interface.go +++ b/cmd/interface.go @@ -1,11 +1,11 @@ -//Package cmd provides all functions related to command line +// Package cmd provides all functions related to command line package cmd import ( "context" "crypto/ecdsa" "math/big" - Accounts "razor/accounts" + "razor/cache" "razor/core/types" "razor/path" "razor/pkg/bindings" @@ -110,32 +110,10 @@ type AssetManagerInterface interface { } type FlagSetInterface interface { - GetStringProvider(flagSet *pflag.FlagSet) (string, error) - GetStringAlternateProvider(flagSet *pflag.FlagSet) (string, error) - GetFloat32GasMultiplier(flagSet *pflag.FlagSet) (float32, error) - GetInt32Buffer(flagSet *pflag.FlagSet) (int32, error) - GetInt32Wait(flagSet *pflag.FlagSet) (int32, error) - GetInt32GasPrice(flagSet *pflag.FlagSet) (int32, error) - GetFloat32GasLimit(flagSet *pflag.FlagSet) (float32, error) - GetUint64GasLimitOverride(flagSet *pflag.FlagSet) (uint64, error) - GetStringLogLevel(flagSet *pflag.FlagSet) (string, error) - GetInt64RPCTimeout(flagSet *pflag.FlagSet) (int64, error) - GetInt64HTTPTimeout(flagSet *pflag.FlagSet) (int64, error) + FetchFlagInput(flagSet *pflag.FlagSet, flagKeyword string, dataType string) (interface{}, error) + FetchRootFlagInput(flagName string, dataType string) (interface{}, error) + Changed(flagSet *pflag.FlagSet, flagName string) bool GetUint32BountyId(flagSet *pflag.FlagSet) (uint32, error) - GetRootStringProvider() (string, error) - GetRootStringAlternateProvider() (string, error) - GetRootFloat32GasMultiplier() (float32, error) - GetRootInt32Buffer() (int32, error) - GetRootInt32Wait() (int32, error) - GetRootInt32GasPrice() (int32, error) - GetRootStringLogLevel() (string, error) - GetRootFloat32GasLimit() (float32, error) - GetRootUint64GasLimitOverride() (uint64, error) - GetRootInt64RPCTimeout() (int64, error) - GetRootInt64HTTPTimeout() (int64, error) - GetRootIntLogFileMaxSize() (int, error) - GetRootIntLogFileMaxBackups() (int, error) - GetRootIntLogFileMaxAge() (int, error) GetStringFrom(flagSet *pflag.FlagSet) (string, error) GetStringTo(flagSet *pflag.FlagSet) (string, error) GetStringAddress(flagSet *pflag.FlagSet) (string, error) @@ -187,14 +165,14 @@ type UtilsCmdInterface interface { ClaimBounty(config types.Configurations, client *ethclient.Client, redeemBountyInput types.RedeemBountyInput) (common.Hash, error) ClaimBlockReward(options types.TransactionOptions) (common.Hash, error) GetSalt(client *ethclient.Client, epoch uint32) ([32]byte, error) - HandleCommitState(client *ethclient.Client, epoch uint32, seed []byte, rogueData types.Rogue) (types.CommitData, error) - Commit(client *ethclient.Client, config types.Configurations, account types.Account, epoch uint32, seed []byte, values []*big.Int) (common.Hash, error) + HandleCommitState(client *ethclient.Client, epoch uint32, seed []byte, commitParams *types.CommitParams, rogueData types.Rogue) (types.CommitData, error) + Commit(client *ethclient.Client, config types.Configurations, account types.Account, epoch uint32, latestHeader *Types.Header, seed []byte, values []*big.Int) (common.Hash, error) ListAccounts() ([]accounts.Account, error) AssignAmountInWei(flagSet *pflag.FlagSet) (*big.Int, error) ExecuteTransfer(flagSet *pflag.FlagSet) Transfer(client *ethclient.Client, config types.Configurations, transferInput types.TransferInput) (common.Hash, error) CheckForLastCommitted(client *ethclient.Client, staker bindings.StructsStaker, epoch uint32) error - Reveal(client *ethclient.Client, config types.Configurations, account types.Account, epoch uint32, commitData types.CommitData, signature []byte) (common.Hash, error) + Reveal(client *ethclient.Client, config types.Configurations, account types.Account, epoch uint32, latestHeader *Types.Header, commitData types.CommitData, signature []byte) (common.Hash, error) GenerateTreeRevealData(merkleTree [][][]byte, commitData types.CommitData) bindings.StructsMerkleTree IndexRevealEventsOfCurrentEpoch(client *ethclient.Client, blockNumber *big.Int, epoch uint32) ([]types.RevealedStruct, error) ExecuteCreateJob(flagSet *pflag.FlagSet) @@ -229,7 +207,7 @@ type UtilsCmdInterface interface { IsElectedProposer(proposer types.ElectedProposer, currentStakerStake *big.Int) bool GetSortedRevealedValues(client *ethclient.Client, blockNumber *big.Int, epoch uint32) (*types.RevealedDataMaps, error) GetIteration(client *ethclient.Client, proposer types.ElectedProposer, bufferPercent int32) int - Propose(client *ethclient.Client, config types.Configurations, account types.Account, staker bindings.StructsStaker, epoch uint32, blockNumber *big.Int, rogueData types.Rogue) error + Propose(client *ethclient.Client, config types.Configurations, account types.Account, staker bindings.StructsStaker, epoch uint32, latestHeader *Types.Header, rogueData types.Rogue) error GiveSorted(client *ethclient.Client, blockManager *bindings.BlockManager, txnArgs types.TransactionOptions, epoch uint32, assetId uint16, sortedStakers []*big.Int) error GetLocalMediansData(client *ethclient.Client, account types.Account, epoch uint32, blockNumber *big.Int, rogueData types.Rogue) (types.ProposeFileData, error) CheckDisputeForIds(client *ethclient.Client, transactionOpts types.TransactionOptions, epoch uint32, blockIndex uint8, idsInProposedBlock []uint16, revealedCollectionIds []uint16) (*Types.Transaction, error) @@ -250,20 +228,20 @@ type UtilsCmdInterface interface { ImportAccount() (accounts.Account, error) ExecuteUpdateCommission(flagSet *pflag.FlagSet) UpdateCommission(config types.Configurations, client *ethclient.Client, updateCommissionInput types.UpdateCommissionInput) error - GetBiggestStakeAndId(client *ethclient.Client, address string, epoch uint32) (*big.Int, uint32, error) + GetBiggestStakeAndId(client *ethclient.Client, epoch uint32) (*big.Int, uint32, error) GetSmallestStakeAndId(client *ethclient.Client, epoch uint32) (*big.Int, uint32, error) StakeCoins(txnArgs types.TransactionOptions) (common.Hash, error) CalculateSecret(account types.Account, epoch uint32, keystorePath string, chainId *big.Int) ([]byte, []byte, error) - HandleBlock(client *ethclient.Client, account types.Account, blockNumber *big.Int, config types.Configurations, rogueData types.Rogue, backupNodeActionsToIgnore []string) + HandleBlock(client *ethclient.Client, account types.Account, stakerId uint32, header *Types.Header, config types.Configurations, commitParams *types.CommitParams, rogueData types.Rogue, backupNodeActionsToIgnore []string) ExecuteVote(flagSet *pflag.FlagSet) - Vote(ctx context.Context, config types.Configurations, client *ethclient.Client, rogueData types.Rogue, account types.Account, backupNodeActionsToIgnore []string) error + Vote(ctx context.Context, config types.Configurations, client *ethclient.Client, account types.Account, stakerId uint32, commitParams *types.CommitParams, rogueData types.Rogue, backupNodeActionsToIgnore []string) error HandleExit() ExecuteListAccounts(flagSet *pflag.FlagSet) ClaimCommission(flagSet *pflag.FlagSet) ExecuteStake(flagSet *pflag.FlagSet) - InitiateCommit(client *ethclient.Client, config types.Configurations, account types.Account, epoch uint32, stakerId uint32, rogueData types.Rogue) error - InitiateReveal(client *ethclient.Client, config types.Configurations, account types.Account, epoch uint32, staker bindings.StructsStaker, rogueData types.Rogue) error - InitiatePropose(client *ethclient.Client, config types.Configurations, account types.Account, epoch uint32, staker bindings.StructsStaker, blockNumber *big.Int, rogueData types.Rogue) error + InitiateCommit(client *ethclient.Client, config types.Configurations, account types.Account, epoch uint32, stakerId uint32, latestHeader *Types.Header, commitParams *types.CommitParams, rogueData types.Rogue) error + InitiateReveal(client *ethclient.Client, config types.Configurations, account types.Account, epoch uint32, staker bindings.StructsStaker, latestHeader *Types.Header, rogueData types.Rogue) error + InitiatePropose(client *ethclient.Client, config types.Configurations, account types.Account, epoch uint32, staker bindings.StructsStaker, latestHeader *Types.Header, rogueData types.Rogue) error GetBountyIdFromEvents(client *ethclient.Client, blockNumber *big.Int, bountyHunter string) (uint32, error) HandleClaimBounty(client *ethclient.Client, config types.Configurations, account types.Account) error ExecuteContractAddresses(flagSet *pflag.FlagSet) @@ -271,6 +249,8 @@ type UtilsCmdInterface interface { ResetDispute(client *ethclient.Client, blockManager *bindings.BlockManager, txnOpts *bind.TransactOpts, epoch uint32) StoreBountyId(client *ethclient.Client, account types.Account) error CheckToDoResetDispute(client *ethclient.Client, blockManager *bindings.BlockManager, txnOpts *bind.TransactOpts, epoch uint32, sortedValues []*big.Int) + InitJobAndCollectionCache(client *ethclient.Client) (*cache.JobsCache, *cache.CollectionsCache, *big.Int, error) + BatchGetStakeSnapshotCalls(client *ethclient.Client, epoch uint32, numberOfStakers uint32) ([]*big.Int, error) } type TransactionInterface interface { @@ -335,7 +315,6 @@ func InitializeInterfaces() { abiUtils = AbiUtils{} osUtils = OSUtils{} - Accounts.AccountUtilsInterface = Accounts.AccountUtils{} path.PathUtilsInterface = path.PathUtils{} path.OSUtilsInterface = path.OSUtils{} InitializeUtils() diff --git a/cmd/mocks/flag_set_interface.go b/cmd/mocks/flag_set_interface.go index 13b981258..e86159625 100644 --- a/cmd/mocks/flag_set_interface.go +++ b/cmd/mocks/flag_set_interface.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.14.0. DO NOT EDIT. +// Code generated by mockery v2.30.1. DO NOT EDIT. package mocks @@ -12,62 +12,39 @@ type FlagSetInterface struct { mock.Mock } -// GetBoolRogue provides a mock function with given fields: flagSet -func (_m *FlagSetInterface) GetBoolRogue(flagSet *pflag.FlagSet) (bool, error) { - ret := _m.Called(flagSet) +// Changed provides a mock function with given fields: flagSet, flagName +func (_m *FlagSetInterface) Changed(flagSet *pflag.FlagSet, flagName string) bool { + ret := _m.Called(flagSet, flagName) var r0 bool - if rf, ok := ret.Get(0).(func(*pflag.FlagSet) bool); ok { - r0 = rf(flagSet) + if rf, ok := ret.Get(0).(func(*pflag.FlagSet, string) bool); ok { + r0 = rf(flagSet, flagName) } else { r0 = ret.Get(0).(bool) } - var r1 error - if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { - r1 = rf(flagSet) - } else { - r1 = ret.Error(1) - } - - return r0, r1 + return r0 } -// GetBoolWeiRazor provides a mock function with given fields: flagSet -func (_m *FlagSetInterface) GetBoolWeiRazor(flagSet *pflag.FlagSet) (bool, error) { - ret := _m.Called(flagSet) - - var r0 bool - if rf, ok := ret.Get(0).(func(*pflag.FlagSet) bool); ok { - r0 = rf(flagSet) - } else { - r0 = ret.Get(0).(bool) - } +// FetchFlagInput provides a mock function with given fields: flagSet, flagKeyword, dataType +func (_m *FlagSetInterface) FetchFlagInput(flagSet *pflag.FlagSet, flagKeyword string, dataType string) (interface{}, error) { + ret := _m.Called(flagSet, flagKeyword, dataType) + var r0 interface{} var r1 error - if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { - r1 = rf(flagSet) - } else { - r1 = ret.Error(1) + if rf, ok := ret.Get(0).(func(*pflag.FlagSet, string, string) (interface{}, error)); ok { + return rf(flagSet, flagKeyword, dataType) } - - return r0, r1 -} - -// GetFloat32GasLimit provides a mock function with given fields: flagSet -func (_m *FlagSetInterface) GetFloat32GasLimit(flagSet *pflag.FlagSet) (float32, error) { - ret := _m.Called(flagSet) - - var r0 float32 - if rf, ok := ret.Get(0).(func(*pflag.FlagSet) float32); ok { - r0 = rf(flagSet) + if rf, ok := ret.Get(0).(func(*pflag.FlagSet, string, string) interface{}); ok { + r0 = rf(flagSet, flagKeyword, dataType) } else { - r0 = ret.Get(0).(float32) + if ret.Get(0) != nil { + r0 = ret.Get(0).(interface{}) + } } - var r1 error - if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { - r1 = rf(flagSet) + if rf, ok := ret.Get(1).(func(*pflag.FlagSet, string, string) error); ok { + r1 = rf(flagSet, flagKeyword, dataType) } else { r1 = ret.Error(1) } @@ -75,41 +52,25 @@ func (_m *FlagSetInterface) GetFloat32GasLimit(flagSet *pflag.FlagSet) (float32, return r0, r1 } -// GetFloat32GasMultiplier provides a mock function with given fields: flagSet -func (_m *FlagSetInterface) GetFloat32GasMultiplier(flagSet *pflag.FlagSet) (float32, error) { - ret := _m.Called(flagSet) - - var r0 float32 - if rf, ok := ret.Get(0).(func(*pflag.FlagSet) float32); ok { - r0 = rf(flagSet) - } else { - r0 = ret.Get(0).(float32) - } +// FetchRootFlagInput provides a mock function with given fields: flagName, dataType +func (_m *FlagSetInterface) FetchRootFlagInput(flagName string, dataType string) (interface{}, error) { + ret := _m.Called(flagName, dataType) + var r0 interface{} var r1 error - if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { - r1 = rf(flagSet) - } else { - r1 = ret.Error(1) + if rf, ok := ret.Get(0).(func(string, string) (interface{}, error)); ok { + return rf(flagName, dataType) } - - return r0, r1 -} - -// GetInt32Buffer provides a mock function with given fields: flagSet -func (_m *FlagSetInterface) GetInt32Buffer(flagSet *pflag.FlagSet) (int32, error) { - ret := _m.Called(flagSet) - - var r0 int32 - if rf, ok := ret.Get(0).(func(*pflag.FlagSet) int32); ok { - r0 = rf(flagSet) + if rf, ok := ret.Get(0).(func(string, string) interface{}); ok { + r0 = rf(flagName, dataType) } else { - r0 = ret.Get(0).(int32) + if ret.Get(0) != nil { + r0 = ret.Get(0).(interface{}) + } } - var r1 error - if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { - r1 = rf(flagSet) + if rf, ok := ret.Get(1).(func(string, string) error); ok { + r1 = rf(flagName, dataType) } else { r1 = ret.Error(1) } @@ -117,39 +78,21 @@ func (_m *FlagSetInterface) GetInt32Buffer(flagSet *pflag.FlagSet) (int32, error return r0, r1 } -// GetInt32GasPrice provides a mock function with given fields: flagSet -func (_m *FlagSetInterface) GetInt32GasPrice(flagSet *pflag.FlagSet) (int32, error) { +// GetBoolRogue provides a mock function with given fields: flagSet +func (_m *FlagSetInterface) GetBoolRogue(flagSet *pflag.FlagSet) (bool, error) { ret := _m.Called(flagSet) - var r0 int32 - if rf, ok := ret.Get(0).(func(*pflag.FlagSet) int32); ok { - r0 = rf(flagSet) - } else { - r0 = ret.Get(0).(int32) - } - + var r0 bool var r1 error - if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { - r1 = rf(flagSet) - } else { - r1 = ret.Error(1) + if rf, ok := ret.Get(0).(func(*pflag.FlagSet) (bool, error)); ok { + return rf(flagSet) } - - return r0, r1 -} - -// GetInt32Wait provides a mock function with given fields: flagSet -func (_m *FlagSetInterface) GetInt32Wait(flagSet *pflag.FlagSet) (int32, error) { - ret := _m.Called(flagSet) - - var r0 int32 - if rf, ok := ret.Get(0).(func(*pflag.FlagSet) int32); ok { + if rf, ok := ret.Get(0).(func(*pflag.FlagSet) bool); ok { r0 = rf(flagSet) } else { - r0 = ret.Get(0).(int32) + r0 = ret.Get(0).(bool) } - var r1 error if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { r1 = rf(flagSet) } else { @@ -159,39 +102,21 @@ func (_m *FlagSetInterface) GetInt32Wait(flagSet *pflag.FlagSet) (int32, error) return r0, r1 } -// GetInt64HTTPTimeout provides a mock function with given fields: flagSet -func (_m *FlagSetInterface) GetInt64HTTPTimeout(flagSet *pflag.FlagSet) (int64, error) { +// GetBoolWeiRazor provides a mock function with given fields: flagSet +func (_m *FlagSetInterface) GetBoolWeiRazor(flagSet *pflag.FlagSet) (bool, error) { ret := _m.Called(flagSet) - var r0 int64 - if rf, ok := ret.Get(0).(func(*pflag.FlagSet) int64); ok { - r0 = rf(flagSet) - } else { - r0 = ret.Get(0).(int64) - } - + var r0 bool var r1 error - if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { - r1 = rf(flagSet) - } else { - r1 = ret.Error(1) + if rf, ok := ret.Get(0).(func(*pflag.FlagSet) (bool, error)); ok { + return rf(flagSet) } - - return r0, r1 -} - -// GetInt64RPCTimeout provides a mock function with given fields: flagSet -func (_m *FlagSetInterface) GetInt64RPCTimeout(flagSet *pflag.FlagSet) (int64, error) { - ret := _m.Called(flagSet) - - var r0 int64 - if rf, ok := ret.Get(0).(func(*pflag.FlagSet) int64); ok { + if rf, ok := ret.Get(0).(func(*pflag.FlagSet) bool); ok { r0 = rf(flagSet) } else { - r0 = ret.Get(0).(int64) + r0 = ret.Get(0).(bool) } - var r1 error if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { r1 = rf(flagSet) } else { @@ -206,13 +131,16 @@ func (_m *FlagSetInterface) GetInt8Power(flagSet *pflag.FlagSet) (int8, error) { ret := _m.Called(flagSet) var r0 int8 + var r1 error + if rf, ok := ret.Get(0).(func(*pflag.FlagSet) (int8, error)); ok { + return rf(flagSet) + } if rf, ok := ret.Get(0).(func(*pflag.FlagSet) int8); ok { r0 = rf(flagSet) } else { r0 = ret.Get(0).(int8) } - var r1 error if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { r1 = rf(flagSet) } else { @@ -227,13 +155,16 @@ func (_m *FlagSetInterface) GetIntLogFileMaxAge(flagSet *pflag.FlagSet) (int, er ret := _m.Called(flagSet) var r0 int + var r1 error + if rf, ok := ret.Get(0).(func(*pflag.FlagSet) (int, error)); ok { + return rf(flagSet) + } if rf, ok := ret.Get(0).(func(*pflag.FlagSet) int); ok { r0 = rf(flagSet) } else { r0 = ret.Get(0).(int) } - var r1 error if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { r1 = rf(flagSet) } else { @@ -248,13 +179,16 @@ func (_m *FlagSetInterface) GetIntLogFileMaxBackups(flagSet *pflag.FlagSet) (int ret := _m.Called(flagSet) var r0 int + var r1 error + if rf, ok := ret.Get(0).(func(*pflag.FlagSet) (int, error)); ok { + return rf(flagSet) + } if rf, ok := ret.Get(0).(func(*pflag.FlagSet) int); ok { r0 = rf(flagSet) } else { r0 = ret.Get(0).(int) } - var r1 error if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { r1 = rf(flagSet) } else { @@ -269,13 +203,16 @@ func (_m *FlagSetInterface) GetIntLogFileMaxSize(flagSet *pflag.FlagSet) (int, e ret := _m.Called(flagSet) var r0 int + var r1 error + if rf, ok := ret.Get(0).(func(*pflag.FlagSet) (int, error)); ok { + return rf(flagSet) + } if rf, ok := ret.Get(0).(func(*pflag.FlagSet) int); ok { r0 = rf(flagSet) } else { r0 = ret.Get(0).(int) } - var r1 error if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { r1 = rf(flagSet) } else { @@ -285,333 +222,21 @@ func (_m *FlagSetInterface) GetIntLogFileMaxSize(flagSet *pflag.FlagSet) (int, e return r0, r1 } -// GetRootFloat32GasLimit provides a mock function with given fields: -func (_m *FlagSetInterface) GetRootFloat32GasLimit() (float32, error) { - ret := _m.Called() - - var r0 float32 - if rf, ok := ret.Get(0).(func() float32); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(float32) - } - - var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// GetRootFloat32GasMultiplier provides a mock function with given fields: -func (_m *FlagSetInterface) GetRootFloat32GasMultiplier() (float32, error) { - ret := _m.Called() - - var r0 float32 - if rf, ok := ret.Get(0).(func() float32); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(float32) - } - - var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// GetRootInt32Buffer provides a mock function with given fields: -func (_m *FlagSetInterface) GetRootInt32Buffer() (int32, error) { - ret := _m.Called() - - var r0 int32 - if rf, ok := ret.Get(0).(func() int32); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(int32) - } - - var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// GetRootInt32GasPrice provides a mock function with given fields: -func (_m *FlagSetInterface) GetRootInt32GasPrice() (int32, error) { - ret := _m.Called() - - var r0 int32 - if rf, ok := ret.Get(0).(func() int32); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(int32) - } - - var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// GetRootInt32Wait provides a mock function with given fields: -func (_m *FlagSetInterface) GetRootInt32Wait() (int32, error) { - ret := _m.Called() - - var r0 int32 - if rf, ok := ret.Get(0).(func() int32); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(int32) - } - - var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// GetRootInt64HTTPTimeout provides a mock function with given fields: -func (_m *FlagSetInterface) GetRootInt64HTTPTimeout() (int64, error) { - ret := _m.Called() - - var r0 int64 - if rf, ok := ret.Get(0).(func() int64); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(int64) - } - - var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// GetRootInt64RPCTimeout provides a mock function with given fields: -func (_m *FlagSetInterface) GetRootInt64RPCTimeout() (int64, error) { - ret := _m.Called() - - var r0 int64 - if rf, ok := ret.Get(0).(func() int64); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(int64) - } - - var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// GetRootIntLogFileMaxAge provides a mock function with given fields: -func (_m *FlagSetInterface) GetRootIntLogFileMaxAge() (int, error) { - ret := _m.Called() - - var r0 int - if rf, ok := ret.Get(0).(func() int); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(int) - } - - var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// GetRootIntLogFileMaxBackups provides a mock function with given fields: -func (_m *FlagSetInterface) GetRootIntLogFileMaxBackups() (int, error) { - ret := _m.Called() - - var r0 int - if rf, ok := ret.Get(0).(func() int); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(int) - } - - var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// GetRootIntLogFileMaxSize provides a mock function with given fields: -func (_m *FlagSetInterface) GetRootIntLogFileMaxSize() (int, error) { - ret := _m.Called() - - var r0 int - if rf, ok := ret.Get(0).(func() int); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(int) - } - - var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// GetRootStringAlternateProvider provides a mock function with given fields: -func (_m *FlagSetInterface) GetRootStringAlternateProvider() (string, error) { - ret := _m.Called() - - var r0 string - if rf, ok := ret.Get(0).(func() string); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(string) - } - - var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// GetRootStringLogLevel provides a mock function with given fields: -func (_m *FlagSetInterface) GetRootStringLogLevel() (string, error) { - ret := _m.Called() - - var r0 string - if rf, ok := ret.Get(0).(func() string); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(string) - } - - var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// GetRootStringProvider provides a mock function with given fields: -func (_m *FlagSetInterface) GetRootStringProvider() (string, error) { - ret := _m.Called() - - var r0 string - if rf, ok := ret.Get(0).(func() string); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(string) - } - - var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// GetRootUint64GasLimitOverride provides a mock function with given fields: -func (_m *FlagSetInterface) GetRootUint64GasLimitOverride() (uint64, error) { - ret := _m.Called() - - var r0 uint64 - if rf, ok := ret.Get(0).(func() uint64); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(uint64) - } - - var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // GetStringAddress provides a mock function with given fields: flagSet func (_m *FlagSetInterface) GetStringAddress(flagSet *pflag.FlagSet) (string, error) { ret := _m.Called(flagSet) var r0 string - if rf, ok := ret.Get(0).(func(*pflag.FlagSet) string); ok { - r0 = rf(flagSet) - } else { - r0 = ret.Get(0).(string) - } - var r1 error - if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { - r1 = rf(flagSet) - } else { - r1 = ret.Error(1) + if rf, ok := ret.Get(0).(func(*pflag.FlagSet) (string, error)); ok { + return rf(flagSet) } - - return r0, r1 -} - -// GetStringAlternateProvider provides a mock function with given fields: flagSet -func (_m *FlagSetInterface) GetStringAlternateProvider(flagSet *pflag.FlagSet) (string, error) { - ret := _m.Called(flagSet) - - var r0 string if rf, ok := ret.Get(0).(func(*pflag.FlagSet) string); ok { r0 = rf(flagSet) } else { r0 = ret.Get(0).(string) } - var r1 error if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { r1 = rf(flagSet) } else { @@ -626,13 +251,16 @@ func (_m *FlagSetInterface) GetStringCertFile(flagSet *pflag.FlagSet) (string, e ret := _m.Called(flagSet) var r0 string + var r1 error + if rf, ok := ret.Get(0).(func(*pflag.FlagSet) (string, error)); ok { + return rf(flagSet) + } if rf, ok := ret.Get(0).(func(*pflag.FlagSet) string); ok { r0 = rf(flagSet) } else { r0 = ret.Get(0).(string) } - var r1 error if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { r1 = rf(flagSet) } else { @@ -647,13 +275,16 @@ func (_m *FlagSetInterface) GetStringCertKey(flagSet *pflag.FlagSet) (string, er ret := _m.Called(flagSet) var r0 string + var r1 error + if rf, ok := ret.Get(0).(func(*pflag.FlagSet) (string, error)); ok { + return rf(flagSet) + } if rf, ok := ret.Get(0).(func(*pflag.FlagSet) string); ok { r0 = rf(flagSet) } else { r0 = ret.Get(0).(string) } - var r1 error if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { r1 = rf(flagSet) } else { @@ -668,13 +299,16 @@ func (_m *FlagSetInterface) GetStringExposeMetrics(flagSet *pflag.FlagSet) (stri ret := _m.Called(flagSet) var r0 string + var r1 error + if rf, ok := ret.Get(0).(func(*pflag.FlagSet) (string, error)); ok { + return rf(flagSet) + } if rf, ok := ret.Get(0).(func(*pflag.FlagSet) string); ok { r0 = rf(flagSet) } else { r0 = ret.Get(0).(string) } - var r1 error if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { r1 = rf(flagSet) } else { @@ -689,34 +323,16 @@ func (_m *FlagSetInterface) GetStringFrom(flagSet *pflag.FlagSet) (string, error ret := _m.Called(flagSet) var r0 string - if rf, ok := ret.Get(0).(func(*pflag.FlagSet) string); ok { - r0 = rf(flagSet) - } else { - r0 = ret.Get(0).(string) - } - var r1 error - if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { - r1 = rf(flagSet) - } else { - r1 = ret.Error(1) + if rf, ok := ret.Get(0).(func(*pflag.FlagSet) (string, error)); ok { + return rf(flagSet) } - - return r0, r1 -} - -// GetStringLogLevel provides a mock function with given fields: flagSet -func (_m *FlagSetInterface) GetStringLogLevel(flagSet *pflag.FlagSet) (string, error) { - ret := _m.Called(flagSet) - - var r0 string if rf, ok := ret.Get(0).(func(*pflag.FlagSet) string); ok { r0 = rf(flagSet) } else { r0 = ret.Get(0).(string) } - var r1 error if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { r1 = rf(flagSet) } else { @@ -731,34 +347,16 @@ func (_m *FlagSetInterface) GetStringName(flagSet *pflag.FlagSet) (string, error ret := _m.Called(flagSet) var r0 string - if rf, ok := ret.Get(0).(func(*pflag.FlagSet) string); ok { - r0 = rf(flagSet) - } else { - r0 = ret.Get(0).(string) - } - var r1 error - if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { - r1 = rf(flagSet) - } else { - r1 = ret.Error(1) + if rf, ok := ret.Get(0).(func(*pflag.FlagSet) (string, error)); ok { + return rf(flagSet) } - - return r0, r1 -} - -// GetStringProvider provides a mock function with given fields: flagSet -func (_m *FlagSetInterface) GetStringProvider(flagSet *pflag.FlagSet) (string, error) { - ret := _m.Called(flagSet) - - var r0 string if rf, ok := ret.Get(0).(func(*pflag.FlagSet) string); ok { r0 = rf(flagSet) } else { r0 = ret.Get(0).(string) } - var r1 error if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { r1 = rf(flagSet) } else { @@ -773,13 +371,16 @@ func (_m *FlagSetInterface) GetStringSelector(flagSet *pflag.FlagSet) (string, e ret := _m.Called(flagSet) var r0 string + var r1 error + if rf, ok := ret.Get(0).(func(*pflag.FlagSet) (string, error)); ok { + return rf(flagSet) + } if rf, ok := ret.Get(0).(func(*pflag.FlagSet) string); ok { r0 = rf(flagSet) } else { r0 = ret.Get(0).(string) } - var r1 error if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { r1 = rf(flagSet) } else { @@ -794,6 +395,10 @@ func (_m *FlagSetInterface) GetStringSliceBackupNode(flagSet *pflag.FlagSet) ([] ret := _m.Called(flagSet) var r0 []string + var r1 error + if rf, ok := ret.Get(0).(func(*pflag.FlagSet) ([]string, error)); ok { + return rf(flagSet) + } if rf, ok := ret.Get(0).(func(*pflag.FlagSet) []string); ok { r0 = rf(flagSet) } else { @@ -802,7 +407,6 @@ func (_m *FlagSetInterface) GetStringSliceBackupNode(flagSet *pflag.FlagSet) ([] } } - var r1 error if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { r1 = rf(flagSet) } else { @@ -817,6 +421,10 @@ func (_m *FlagSetInterface) GetStringSliceRogueMode(flagSet *pflag.FlagSet) ([]s ret := _m.Called(flagSet) var r0 []string + var r1 error + if rf, ok := ret.Get(0).(func(*pflag.FlagSet) ([]string, error)); ok { + return rf(flagSet) + } if rf, ok := ret.Get(0).(func(*pflag.FlagSet) []string); ok { r0 = rf(flagSet) } else { @@ -825,7 +433,6 @@ func (_m *FlagSetInterface) GetStringSliceRogueMode(flagSet *pflag.FlagSet) ([]s } } - var r1 error if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { r1 = rf(flagSet) } else { @@ -840,13 +447,16 @@ func (_m *FlagSetInterface) GetStringStatus(flagSet *pflag.FlagSet) (string, err ret := _m.Called(flagSet) var r0 string + var r1 error + if rf, ok := ret.Get(0).(func(*pflag.FlagSet) (string, error)); ok { + return rf(flagSet) + } if rf, ok := ret.Get(0).(func(*pflag.FlagSet) string); ok { r0 = rf(flagSet) } else { r0 = ret.Get(0).(string) } - var r1 error if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { r1 = rf(flagSet) } else { @@ -861,13 +471,16 @@ func (_m *FlagSetInterface) GetStringTo(flagSet *pflag.FlagSet) (string, error) ret := _m.Called(flagSet) var r0 string + var r1 error + if rf, ok := ret.Get(0).(func(*pflag.FlagSet) (string, error)); ok { + return rf(flagSet) + } if rf, ok := ret.Get(0).(func(*pflag.FlagSet) string); ok { r0 = rf(flagSet) } else { r0 = ret.Get(0).(string) } - var r1 error if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { r1 = rf(flagSet) } else { @@ -882,13 +495,16 @@ func (_m *FlagSetInterface) GetStringUrl(flagSet *pflag.FlagSet) (string, error) ret := _m.Called(flagSet) var r0 string + var r1 error + if rf, ok := ret.Get(0).(func(*pflag.FlagSet) (string, error)); ok { + return rf(flagSet) + } if rf, ok := ret.Get(0).(func(*pflag.FlagSet) string); ok { r0 = rf(flagSet) } else { r0 = ret.Get(0).(string) } - var r1 error if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { r1 = rf(flagSet) } else { @@ -903,13 +519,16 @@ func (_m *FlagSetInterface) GetStringValue(flagSet *pflag.FlagSet) (string, erro ret := _m.Called(flagSet) var r0 string + var r1 error + if rf, ok := ret.Get(0).(func(*pflag.FlagSet) (string, error)); ok { + return rf(flagSet) + } if rf, ok := ret.Get(0).(func(*pflag.FlagSet) string); ok { r0 = rf(flagSet) } else { r0 = ret.Get(0).(string) } - var r1 error if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { r1 = rf(flagSet) } else { @@ -924,13 +543,16 @@ func (_m *FlagSetInterface) GetUint16CollectionId(flagSet *pflag.FlagSet) (uint1 ret := _m.Called(flagSet) var r0 uint16 + var r1 error + if rf, ok := ret.Get(0).(func(*pflag.FlagSet) (uint16, error)); ok { + return rf(flagSet) + } if rf, ok := ret.Get(0).(func(*pflag.FlagSet) uint16); ok { r0 = rf(flagSet) } else { r0 = ret.Get(0).(uint16) } - var r1 error if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { r1 = rf(flagSet) } else { @@ -945,13 +567,16 @@ func (_m *FlagSetInterface) GetUint16JobId(flagSet *pflag.FlagSet) (uint16, erro ret := _m.Called(flagSet) var r0 uint16 + var r1 error + if rf, ok := ret.Get(0).(func(*pflag.FlagSet) (uint16, error)); ok { + return rf(flagSet) + } if rf, ok := ret.Get(0).(func(*pflag.FlagSet) uint16); ok { r0 = rf(flagSet) } else { r0 = ret.Get(0).(uint16) } - var r1 error if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { r1 = rf(flagSet) } else { @@ -966,13 +591,16 @@ func (_m *FlagSetInterface) GetUint32Aggregation(flagSet *pflag.FlagSet) (uint32 ret := _m.Called(flagSet) var r0 uint32 + var r1 error + if rf, ok := ret.Get(0).(func(*pflag.FlagSet) (uint32, error)); ok { + return rf(flagSet) + } if rf, ok := ret.Get(0).(func(*pflag.FlagSet) uint32); ok { r0 = rf(flagSet) } else { r0 = ret.Get(0).(uint32) } - var r1 error if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { r1 = rf(flagSet) } else { @@ -987,13 +615,16 @@ func (_m *FlagSetInterface) GetUint32BountyId(flagSet *pflag.FlagSet) (uint32, e ret := _m.Called(flagSet) var r0 uint32 + var r1 error + if rf, ok := ret.Get(0).(func(*pflag.FlagSet) (uint32, error)); ok { + return rf(flagSet) + } if rf, ok := ret.Get(0).(func(*pflag.FlagSet) uint32); ok { r0 = rf(flagSet) } else { r0 = ret.Get(0).(uint32) } - var r1 error if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { r1 = rf(flagSet) } else { @@ -1008,13 +639,16 @@ func (_m *FlagSetInterface) GetUint32StakerId(flagSet *pflag.FlagSet) (uint32, e ret := _m.Called(flagSet) var r0 uint32 + var r1 error + if rf, ok := ret.Get(0).(func(*pflag.FlagSet) (uint32, error)); ok { + return rf(flagSet) + } if rf, ok := ret.Get(0).(func(*pflag.FlagSet) uint32); ok { r0 = rf(flagSet) } else { r0 = ret.Get(0).(uint32) } - var r1 error if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { r1 = rf(flagSet) } else { @@ -1029,34 +663,16 @@ func (_m *FlagSetInterface) GetUint32Tolerance(flagSet *pflag.FlagSet) (uint32, ret := _m.Called(flagSet) var r0 uint32 - if rf, ok := ret.Get(0).(func(*pflag.FlagSet) uint32); ok { - r0 = rf(flagSet) - } else { - r0 = ret.Get(0).(uint32) - } - var r1 error - if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { - r1 = rf(flagSet) - } else { - r1 = ret.Error(1) + if rf, ok := ret.Get(0).(func(*pflag.FlagSet) (uint32, error)); ok { + return rf(flagSet) } - - return r0, r1 -} - -// GetUint64GasLimitOverride provides a mock function with given fields: flagSet -func (_m *FlagSetInterface) GetUint64GasLimitOverride(flagSet *pflag.FlagSet) (uint64, error) { - ret := _m.Called(flagSet) - - var r0 uint64 - if rf, ok := ret.Get(0).(func(*pflag.FlagSet) uint64); ok { + if rf, ok := ret.Get(0).(func(*pflag.FlagSet) uint32); ok { r0 = rf(flagSet) } else { - r0 = ret.Get(0).(uint64) + r0 = ret.Get(0).(uint32) } - var r1 error if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { r1 = rf(flagSet) } else { @@ -1071,13 +687,16 @@ func (_m *FlagSetInterface) GetUint8Commission(flagSet *pflag.FlagSet) (uint8, e ret := _m.Called(flagSet) var r0 uint8 + var r1 error + if rf, ok := ret.Get(0).(func(*pflag.FlagSet) (uint8, error)); ok { + return rf(flagSet) + } if rf, ok := ret.Get(0).(func(*pflag.FlagSet) uint8); ok { r0 = rf(flagSet) } else { r0 = ret.Get(0).(uint8) } - var r1 error if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { r1 = rf(flagSet) } else { @@ -1092,13 +711,16 @@ func (_m *FlagSetInterface) GetUint8SelectorType(flagSet *pflag.FlagSet) (uint8, ret := _m.Called(flagSet) var r0 uint8 + var r1 error + if rf, ok := ret.Get(0).(func(*pflag.FlagSet) (uint8, error)); ok { + return rf(flagSet) + } if rf, ok := ret.Get(0).(func(*pflag.FlagSet) uint8); ok { r0 = rf(flagSet) } else { r0 = ret.Get(0).(uint8) } - var r1 error if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { r1 = rf(flagSet) } else { @@ -1113,13 +735,16 @@ func (_m *FlagSetInterface) GetUint8Weight(flagSet *pflag.FlagSet) (uint8, error ret := _m.Called(flagSet) var r0 uint8 + var r1 error + if rf, ok := ret.Get(0).(func(*pflag.FlagSet) (uint8, error)); ok { + return rf(flagSet) + } if rf, ok := ret.Get(0).(func(*pflag.FlagSet) uint8); ok { r0 = rf(flagSet) } else { r0 = ret.Get(0).(uint8) } - var r1 error if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { r1 = rf(flagSet) } else { @@ -1134,6 +759,10 @@ func (_m *FlagSetInterface) GetUintSliceJobIds(flagSet *pflag.FlagSet) ([]uint, ret := _m.Called(flagSet) var r0 []uint + var r1 error + if rf, ok := ret.Get(0).(func(*pflag.FlagSet) ([]uint, error)); ok { + return rf(flagSet) + } if rf, ok := ret.Get(0).(func(*pflag.FlagSet) []uint); ok { r0 = rf(flagSet) } else { @@ -1142,7 +771,6 @@ func (_m *FlagSetInterface) GetUintSliceJobIds(flagSet *pflag.FlagSet) ([]uint, } } - var r1 error if rf, ok := ret.Get(1).(func(*pflag.FlagSet) error); ok { r1 = rf(flagSet) } else { @@ -1152,17 +780,16 @@ func (_m *FlagSetInterface) GetUintSliceJobIds(flagSet *pflag.FlagSet) ([]uint, return r0, r1 } -type mockConstructorTestingTNewFlagSetInterface interface { +// NewFlagSetInterface creates a new instance of FlagSetInterface. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewFlagSetInterface(t interface { mock.TestingT Cleanup(func()) -} - -// NewFlagSetInterface creates a new instance of FlagSetInterface. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewFlagSetInterface(t mockConstructorTestingTNewFlagSetInterface) *FlagSetInterface { +}) *FlagSetInterface { mock := &FlagSetInterface{} mock.Mock.Test(t) t.Cleanup(func() { mock.AssertExpectations(t) }) return mock -} +} \ No newline at end of file diff --git a/cmd/mocks/utils_cmd_interface.go b/cmd/mocks/utils_cmd_interface.go index 9b8b794a0..b957be0e5 100644 --- a/cmd/mocks/utils_cmd_interface.go +++ b/cmd/mocks/utils_cmd_interface.go @@ -11,6 +11,8 @@ import ( bindings "razor/pkg/bindings" + cache "razor/cache" + common "github.com/ethereum/go-ethereum/common" context "context" @@ -109,6 +111,32 @@ func (_m *UtilsCmdInterface) AssignAmountInWei(flagSet *pflag.FlagSet) (*big.Int return r0, r1 } +// BatchGetStakeSnapshotCalls provides a mock function with given fields: client, epoch, numberOfStakers +func (_m *UtilsCmdInterface) BatchGetStakeSnapshotCalls(client *ethclient.Client, epoch uint32, numberOfStakers uint32) ([]*big.Int, error) { + ret := _m.Called(client, epoch, numberOfStakers) + + var r0 []*big.Int + var r1 error + if rf, ok := ret.Get(0).(func(*ethclient.Client, uint32, uint32) ([]*big.Int, error)); ok { + return rf(client, epoch, numberOfStakers) + } + if rf, ok := ret.Get(0).(func(*ethclient.Client, uint32, uint32) []*big.Int); ok { + r0 = rf(client, epoch, numberOfStakers) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*big.Int) + } + } + + if rf, ok := ret.Get(1).(func(*ethclient.Client, uint32, uint32) error); ok { + r1 = rf(client, epoch, numberOfStakers) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // CalculateSecret provides a mock function with given fields: account, epoch, keystorePath, chainId func (_m *UtilsCmdInterface) CalculateSecret(account types.Account, epoch uint32, keystorePath string, chainId *big.Int) ([]byte, []byte, error) { ret := _m.Called(account, epoch, keystorePath, chainId) @@ -270,25 +298,25 @@ func (_m *UtilsCmdInterface) ClaimCommission(flagSet *pflag.FlagSet) { _m.Called(flagSet) } -// Commit provides a mock function with given fields: client, config, account, epoch, seed, values -func (_m *UtilsCmdInterface) Commit(client *ethclient.Client, config types.Configurations, account types.Account, epoch uint32, seed []byte, values []*big.Int) (common.Hash, error) { - ret := _m.Called(client, config, account, epoch, seed, values) +// Commit provides a mock function with given fields: client, config, account, epoch, latestHeader, seed, values +func (_m *UtilsCmdInterface) Commit(client *ethclient.Client, config types.Configurations, account types.Account, epoch uint32, latestHeader *coretypes.Header, seed []byte, values []*big.Int) (common.Hash, error) { + ret := _m.Called(client, config, account, epoch, latestHeader, seed, values) var r0 common.Hash var r1 error - if rf, ok := ret.Get(0).(func(*ethclient.Client, types.Configurations, types.Account, uint32, []byte, []*big.Int) (common.Hash, error)); ok { - return rf(client, config, account, epoch, seed, values) + if rf, ok := ret.Get(0).(func(*ethclient.Client, types.Configurations, types.Account, uint32, *coretypes.Header, []byte, []*big.Int) (common.Hash, error)); ok { + return rf(client, config, account, epoch, latestHeader, seed, values) } - if rf, ok := ret.Get(0).(func(*ethclient.Client, types.Configurations, types.Account, uint32, []byte, []*big.Int) common.Hash); ok { - r0 = rf(client, config, account, epoch, seed, values) + if rf, ok := ret.Get(0).(func(*ethclient.Client, types.Configurations, types.Account, uint32, *coretypes.Header, []byte, []*big.Int) common.Hash); ok { + r0 = rf(client, config, account, epoch, latestHeader, seed, values) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(common.Hash) } } - if rf, ok := ret.Get(1).(func(*ethclient.Client, types.Configurations, types.Account, uint32, []byte, []*big.Int) error); ok { - r1 = rf(client, config, account, epoch, seed, values) + if rf, ok := ret.Get(1).(func(*ethclient.Client, types.Configurations, types.Account, uint32, *coretypes.Header, []byte, []*big.Int) error); ok { + r1 = rf(client, config, account, epoch, latestHeader, seed, values) } else { r1 = ret.Error(1) } @@ -570,32 +598,32 @@ func (_m *UtilsCmdInterface) GetAlternateProvider() (string, error) { return r0, r1 } -// GetBiggestStakeAndId provides a mock function with given fields: client, address, epoch -func (_m *UtilsCmdInterface) GetBiggestStakeAndId(client *ethclient.Client, address string, epoch uint32) (*big.Int, uint32, error) { - ret := _m.Called(client, address, epoch) +// GetBiggestStakeAndId provides a mock function with given fields: client, epoch +func (_m *UtilsCmdInterface) GetBiggestStakeAndId(client *ethclient.Client, epoch uint32) (*big.Int, uint32, error) { + ret := _m.Called(client, epoch) var r0 *big.Int var r1 uint32 var r2 error - if rf, ok := ret.Get(0).(func(*ethclient.Client, string, uint32) (*big.Int, uint32, error)); ok { - return rf(client, address, epoch) + if rf, ok := ret.Get(0).(func(*ethclient.Client, uint32) (*big.Int, uint32, error)); ok { + return rf(client, epoch) } - if rf, ok := ret.Get(0).(func(*ethclient.Client, string, uint32) *big.Int); ok { - r0 = rf(client, address, epoch) + if rf, ok := ret.Get(0).(func(*ethclient.Client, uint32) *big.Int); ok { + r0 = rf(client, epoch) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*big.Int) } } - if rf, ok := ret.Get(1).(func(*ethclient.Client, string, uint32) uint32); ok { - r1 = rf(client, address, epoch) + if rf, ok := ret.Get(1).(func(*ethclient.Client, uint32) uint32); ok { + r1 = rf(client, epoch) } else { r1 = ret.Get(1).(uint32) } - if rf, ok := ret.Get(2).(func(*ethclient.Client, string, uint32) error); ok { - r2 = rf(client, address, epoch) + if rf, ok := ret.Get(2).(func(*ethclient.Client, uint32) error); ok { + r2 = rf(client, epoch) } else { r2 = ret.Error(2) } @@ -1189,9 +1217,9 @@ func (_m *UtilsCmdInterface) GiveSorted(client *ethclient.Client, blockManager * return r0 } -// HandleBlock provides a mock function with given fields: client, account, blockNumber, config, rogueData, backupNodeActionsToIgnore -func (_m *UtilsCmdInterface) HandleBlock(client *ethclient.Client, account types.Account, blockNumber *big.Int, config types.Configurations, rogueData types.Rogue, backupNodeActionsToIgnore []string) { - _m.Called(client, account, blockNumber, config, rogueData, backupNodeActionsToIgnore) +// HandleBlock provides a mock function with given fields: client, account, stakerId, header, config, commitParams, rogueData, backupNodeActionsToIgnore +func (_m *UtilsCmdInterface) HandleBlock(client *ethclient.Client, account types.Account, stakerId uint32, header *coretypes.Header, config types.Configurations, commitParams *types.CommitParams, rogueData types.Rogue, backupNodeActionsToIgnore []string) { + _m.Called(client, account, stakerId, header, config, commitParams, rogueData, backupNodeActionsToIgnore) } // HandleClaimBounty provides a mock function with given fields: client, config, account @@ -1208,23 +1236,23 @@ func (_m *UtilsCmdInterface) HandleClaimBounty(client *ethclient.Client, config return r0 } -// HandleCommitState provides a mock function with given fields: client, epoch, seed, rogueData -func (_m *UtilsCmdInterface) HandleCommitState(client *ethclient.Client, epoch uint32, seed []byte, rogueData types.Rogue) (types.CommitData, error) { - ret := _m.Called(client, epoch, seed, rogueData) +// HandleCommitState provides a mock function with given fields: client, epoch, seed, commitParams, rogueData +func (_m *UtilsCmdInterface) HandleCommitState(client *ethclient.Client, epoch uint32, seed []byte, commitParams *types.CommitParams, rogueData types.Rogue) (types.CommitData, error) { + ret := _m.Called(client, epoch, seed, commitParams, rogueData) var r0 types.CommitData var r1 error - if rf, ok := ret.Get(0).(func(*ethclient.Client, uint32, []byte, types.Rogue) (types.CommitData, error)); ok { - return rf(client, epoch, seed, rogueData) + if rf, ok := ret.Get(0).(func(*ethclient.Client, uint32, []byte, *types.CommitParams, types.Rogue) (types.CommitData, error)); ok { + return rf(client, epoch, seed, commitParams, rogueData) } - if rf, ok := ret.Get(0).(func(*ethclient.Client, uint32, []byte, types.Rogue) types.CommitData); ok { - r0 = rf(client, epoch, seed, rogueData) + if rf, ok := ret.Get(0).(func(*ethclient.Client, uint32, []byte, *types.CommitParams, types.Rogue) types.CommitData); ok { + r0 = rf(client, epoch, seed, commitParams, rogueData) } else { r0 = ret.Get(0).(types.CommitData) } - if rf, ok := ret.Get(1).(func(*ethclient.Client, uint32, []byte, types.Rogue) error); ok { - r1 = rf(client, epoch, seed, rogueData) + if rf, ok := ret.Get(1).(func(*ethclient.Client, uint32, []byte, *types.CommitParams, types.Rogue) error); ok { + r1 = rf(client, epoch, seed, commitParams, rogueData) } else { r1 = ret.Error(1) } @@ -1353,13 +1381,57 @@ func (_m *UtilsCmdInterface) IndexRevealEventsOfCurrentEpoch(client *ethclient.C return r0, r1 } -// InitiateCommit provides a mock function with given fields: client, config, account, epoch, stakerId, rogueData -func (_m *UtilsCmdInterface) InitiateCommit(client *ethclient.Client, config types.Configurations, account types.Account, epoch uint32, stakerId uint32, rogueData types.Rogue) error { - ret := _m.Called(client, config, account, epoch, stakerId, rogueData) +// InitJobAndCollectionCache provides a mock function with given fields: client +func (_m *UtilsCmdInterface) InitJobAndCollectionCache(client *ethclient.Client) (*cache.JobsCache, *cache.CollectionsCache, *big.Int, error) { + ret := _m.Called(client) + + var r0 *cache.JobsCache + var r1 *cache.CollectionsCache + var r2 *big.Int + var r3 error + if rf, ok := ret.Get(0).(func(*ethclient.Client) (*cache.JobsCache, *cache.CollectionsCache, *big.Int, error)); ok { + return rf(client) + } + if rf, ok := ret.Get(0).(func(*ethclient.Client) *cache.JobsCache); ok { + r0 = rf(client) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*cache.JobsCache) + } + } + + if rf, ok := ret.Get(1).(func(*ethclient.Client) *cache.CollectionsCache); ok { + r1 = rf(client) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*cache.CollectionsCache) + } + } + + if rf, ok := ret.Get(2).(func(*ethclient.Client) *big.Int); ok { + r2 = rf(client) + } else { + if ret.Get(2) != nil { + r2 = ret.Get(2).(*big.Int) + } + } + + if rf, ok := ret.Get(3).(func(*ethclient.Client) error); ok { + r3 = rf(client) + } else { + r3 = ret.Error(3) + } + + return r0, r1, r2, r3 +} + +// InitiateCommit provides a mock function with given fields: client, config, account, epoch, stakerId, latestHeader, commitParams, rogueData +func (_m *UtilsCmdInterface) InitiateCommit(client *ethclient.Client, config types.Configurations, account types.Account, epoch uint32, stakerId uint32, latestHeader *coretypes.Header, commitParams *types.CommitParams, rogueData types.Rogue) error { + ret := _m.Called(client, config, account, epoch, stakerId, latestHeader, commitParams, rogueData) var r0 error - if rf, ok := ret.Get(0).(func(*ethclient.Client, types.Configurations, types.Account, uint32, uint32, types.Rogue) error); ok { - r0 = rf(client, config, account, epoch, stakerId, rogueData) + if rf, ok := ret.Get(0).(func(*ethclient.Client, types.Configurations, types.Account, uint32, uint32, *coretypes.Header, *types.CommitParams, types.Rogue) error); ok { + r0 = rf(client, config, account, epoch, stakerId, latestHeader, commitParams, rogueData) } else { r0 = ret.Error(0) } @@ -1367,13 +1439,13 @@ func (_m *UtilsCmdInterface) InitiateCommit(client *ethclient.Client, config typ return r0 } -// InitiatePropose provides a mock function with given fields: client, config, account, epoch, staker, blockNumber, rogueData -func (_m *UtilsCmdInterface) InitiatePropose(client *ethclient.Client, config types.Configurations, account types.Account, epoch uint32, staker bindings.StructsStaker, blockNumber *big.Int, rogueData types.Rogue) error { - ret := _m.Called(client, config, account, epoch, staker, blockNumber, rogueData) +// InitiatePropose provides a mock function with given fields: client, config, account, epoch, staker, latestHeader, rogueData +func (_m *UtilsCmdInterface) InitiatePropose(client *ethclient.Client, config types.Configurations, account types.Account, epoch uint32, staker bindings.StructsStaker, latestHeader *coretypes.Header, rogueData types.Rogue) error { + ret := _m.Called(client, config, account, epoch, staker, latestHeader, rogueData) var r0 error - if rf, ok := ret.Get(0).(func(*ethclient.Client, types.Configurations, types.Account, uint32, bindings.StructsStaker, *big.Int, types.Rogue) error); ok { - r0 = rf(client, config, account, epoch, staker, blockNumber, rogueData) + if rf, ok := ret.Get(0).(func(*ethclient.Client, types.Configurations, types.Account, uint32, bindings.StructsStaker, *coretypes.Header, types.Rogue) error); ok { + r0 = rf(client, config, account, epoch, staker, latestHeader, rogueData) } else { r0 = ret.Error(0) } @@ -1381,13 +1453,13 @@ func (_m *UtilsCmdInterface) InitiatePropose(client *ethclient.Client, config ty return r0 } -// InitiateReveal provides a mock function with given fields: client, config, account, epoch, staker, rogueData -func (_m *UtilsCmdInterface) InitiateReveal(client *ethclient.Client, config types.Configurations, account types.Account, epoch uint32, staker bindings.StructsStaker, rogueData types.Rogue) error { - ret := _m.Called(client, config, account, epoch, staker, rogueData) +// InitiateReveal provides a mock function with given fields: client, config, account, epoch, staker, latestHeader, rogueData +func (_m *UtilsCmdInterface) InitiateReveal(client *ethclient.Client, config types.Configurations, account types.Account, epoch uint32, staker bindings.StructsStaker, latestHeader *coretypes.Header, rogueData types.Rogue) error { + ret := _m.Called(client, config, account, epoch, staker, latestHeader, rogueData) var r0 error - if rf, ok := ret.Get(0).(func(*ethclient.Client, types.Configurations, types.Account, uint32, bindings.StructsStaker, types.Rogue) error); ok { - r0 = rf(client, config, account, epoch, staker, rogueData) + if rf, ok := ret.Get(0).(func(*ethclient.Client, types.Configurations, types.Account, uint32, bindings.StructsStaker, *coretypes.Header, types.Rogue) error); ok { + r0 = rf(client, config, account, epoch, staker, latestHeader, rogueData) } else { r0 = ret.Error(0) } @@ -1531,13 +1603,13 @@ func (_m *UtilsCmdInterface) ModifyCollectionStatus(client *ethclient.Client, co return r0, r1 } -// Propose provides a mock function with given fields: client, config, account, staker, epoch, blockNumber, rogueData -func (_m *UtilsCmdInterface) Propose(client *ethclient.Client, config types.Configurations, account types.Account, staker bindings.StructsStaker, epoch uint32, blockNumber *big.Int, rogueData types.Rogue) error { - ret := _m.Called(client, config, account, staker, epoch, blockNumber, rogueData) +// Propose provides a mock function with given fields: client, config, account, staker, epoch, latestHeader, rogueData +func (_m *UtilsCmdInterface) Propose(client *ethclient.Client, config types.Configurations, account types.Account, staker bindings.StructsStaker, epoch uint32, latestHeader *coretypes.Header, rogueData types.Rogue) error { + ret := _m.Called(client, config, account, staker, epoch, latestHeader, rogueData) var r0 error - if rf, ok := ret.Get(0).(func(*ethclient.Client, types.Configurations, types.Account, bindings.StructsStaker, uint32, *big.Int, types.Rogue) error); ok { - r0 = rf(client, config, account, staker, epoch, blockNumber, rogueData) + if rf, ok := ret.Get(0).(func(*ethclient.Client, types.Configurations, types.Account, bindings.StructsStaker, uint32, *coretypes.Header, types.Rogue) error); ok { + r0 = rf(client, config, account, staker, epoch, latestHeader, rogueData) } else { r0 = ret.Error(0) } @@ -1576,25 +1648,25 @@ func (_m *UtilsCmdInterface) ResetUnstakeLock(client *ethclient.Client, config t return r0, r1 } -// Reveal provides a mock function with given fields: client, config, account, epoch, commitData, signature -func (_m *UtilsCmdInterface) Reveal(client *ethclient.Client, config types.Configurations, account types.Account, epoch uint32, commitData types.CommitData, signature []byte) (common.Hash, error) { - ret := _m.Called(client, config, account, epoch, commitData, signature) +// Reveal provides a mock function with given fields: client, config, account, epoch, latestHeader, commitData, signature +func (_m *UtilsCmdInterface) Reveal(client *ethclient.Client, config types.Configurations, account types.Account, epoch uint32, latestHeader *coretypes.Header, commitData types.CommitData, signature []byte) (common.Hash, error) { + ret := _m.Called(client, config, account, epoch, latestHeader, commitData, signature) var r0 common.Hash var r1 error - if rf, ok := ret.Get(0).(func(*ethclient.Client, types.Configurations, types.Account, uint32, types.CommitData, []byte) (common.Hash, error)); ok { - return rf(client, config, account, epoch, commitData, signature) + if rf, ok := ret.Get(0).(func(*ethclient.Client, types.Configurations, types.Account, uint32, *coretypes.Header, types.CommitData, []byte) (common.Hash, error)); ok { + return rf(client, config, account, epoch, latestHeader, commitData, signature) } - if rf, ok := ret.Get(0).(func(*ethclient.Client, types.Configurations, types.Account, uint32, types.CommitData, []byte) common.Hash); ok { - r0 = rf(client, config, account, epoch, commitData, signature) + if rf, ok := ret.Get(0).(func(*ethclient.Client, types.Configurations, types.Account, uint32, *coretypes.Header, types.CommitData, []byte) common.Hash); ok { + r0 = rf(client, config, account, epoch, latestHeader, commitData, signature) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(common.Hash) } } - if rf, ok := ret.Get(1).(func(*ethclient.Client, types.Configurations, types.Account, uint32, types.CommitData, []byte) error); ok { - r1 = rf(client, config, account, epoch, commitData, signature) + if rf, ok := ret.Get(1).(func(*ethclient.Client, types.Configurations, types.Account, uint32, *coretypes.Header, types.CommitData, []byte) error); ok { + r1 = rf(client, config, account, epoch, latestHeader, commitData, signature) } else { r1 = ret.Error(1) } @@ -1826,13 +1898,13 @@ func (_m *UtilsCmdInterface) UpdateJob(client *ethclient.Client, config types.Co return r0, r1 } -// Vote provides a mock function with given fields: ctx, config, client, rogueData, account, backupNodeActionsToIgnore -func (_m *UtilsCmdInterface) Vote(ctx context.Context, config types.Configurations, client *ethclient.Client, rogueData types.Rogue, account types.Account, backupNodeActionsToIgnore []string) error { - ret := _m.Called(ctx, config, client, rogueData, account, backupNodeActionsToIgnore) +// Vote provides a mock function with given fields: ctx, config, client, account, stakerId, commitParams, rogueData, backupNodeActionsToIgnore +func (_m *UtilsCmdInterface) Vote(ctx context.Context, config types.Configurations, client *ethclient.Client, account types.Account, stakerId uint32, commitParams *types.CommitParams, rogueData types.Rogue, backupNodeActionsToIgnore []string) error { + ret := _m.Called(ctx, config, client, account, stakerId, commitParams, rogueData, backupNodeActionsToIgnore) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, types.Configurations, *ethclient.Client, types.Rogue, types.Account, []string) error); ok { - r0 = rf(ctx, config, client, rogueData, account, backupNodeActionsToIgnore) + if rf, ok := ret.Get(0).(func(context.Context, types.Configurations, *ethclient.Client, types.Account, uint32, *types.CommitParams, types.Rogue, []string) error); ok { + r0 = rf(ctx, config, client, account, stakerId, commitParams, rogueData, backupNodeActionsToIgnore) } else { r0 = ret.Error(0) } diff --git a/cmd/modifyCollectionStatus.go b/cmd/modifyCollectionStatus.go index 6e968d2cd..7f16efad6 100644 --- a/cmd/modifyCollectionStatus.go +++ b/cmd/modifyCollectionStatus.go @@ -2,6 +2,7 @@ package cmd import ( + "razor/accounts" "razor/core" "razor/core/types" "razor/logger" @@ -48,7 +49,12 @@ func (*UtilsStruct) ExecuteModifyCollectionStatus(flagSet *pflag.FlagSet) { log.Debug("Getting password...") password := razorUtils.AssignPassword(flagSet) - err = razorUtils.CheckPassword(address, password) + accountManager, err := razorUtils.AccountManagerForKeystore() + utils.CheckError("Error in getting accounts manager for keystore: ", err) + + account := accounts.InitAccountStruct(address, password, accountManager) + + err = razorUtils.CheckPassword(account) utils.CheckError("Error in fetching private key from given password: ", err) collectionId, err := flagSetUtils.GetUint16CollectionId(flagSet) @@ -61,13 +67,11 @@ func (*UtilsStruct) ExecuteModifyCollectionStatus(flagSet *pflag.FlagSet) { utils.CheckError("Error in parsing status: ", err) modifyCollectionInput := types.ModifyCollectionInput{ - Address: address, - Password: password, Status: status, CollectionId: collectionId, + Account: account, } - log.Debugf("Calling ModifyCollectionStatus() with arguments modifyCollectionInput = %+v", modifyCollectionInput) txn, err := cmdUtils.ModifyCollectionStatus(client, config, modifyCollectionInput) utils.CheckError("Error in changing collection active status: ", err) if txn != core.NilHash { @@ -101,14 +105,13 @@ func (*UtilsStruct) ModifyCollectionStatus(client *ethclient.Client, config type txnArgs := types.TransactionOptions{ Client: client, - Password: modifyCollectionInput.Password, - AccountAddress: modifyCollectionInput.Address, ChainId: core.ChainId, Config: config, ContractAddress: core.CollectionManagerAddress, MethodName: "setCollectionStatus", Parameters: []interface{}{modifyCollectionInput.Status, modifyCollectionInput.CollectionId}, ABI: bindings.CollectionManagerMetaData.ABI, + Account: modifyCollectionInput.Account, } txnOpts := razorUtils.GetTxnOpts(txnArgs) diff --git a/cmd/modifyCollectionStatus_test.go b/cmd/modifyCollectionStatus_test.go index f0f157d89..f6fb0d750 100644 --- a/cmd/modifyCollectionStatus_test.go +++ b/cmd/modifyCollectionStatus_test.go @@ -1,11 +1,9 @@ package cmd import ( - "crypto/ecdsa" - "crypto/rand" "errors" - "github.com/ethereum/go-ethereum/crypto" "math/big" + "razor/accounts" "razor/core" "razor/core/types" "testing" @@ -81,10 +79,6 @@ func TestCheckCurrentStatus(t *testing.T) { } func TestModifyAssetStatus(t *testing.T) { - - privateKey, _ := ecdsa.GenerateKey(crypto.S256(), rand.Reader) - txnOpts, _ := bind.NewKeyedTransactorWithChainID(privateKey, big.NewInt(31337)) - var config types.Configurations var client *ethclient.Client @@ -94,7 +88,6 @@ func TestModifyAssetStatus(t *testing.T) { currentStatusErr error epoch uint32 epochErr error - txnOpts *bind.TransactOpts SetCollectionStatus *Types.Transaction SetAssetStatusErr error hash common.Hash @@ -110,7 +103,6 @@ func TestModifyAssetStatus(t *testing.T) { args: args{ status: true, currentStatus: false, - txnOpts: txnOpts, SetCollectionStatus: &Types.Transaction{}, hash: common.BigToHash(big.NewInt(1)), }, @@ -122,7 +114,6 @@ func TestModifyAssetStatus(t *testing.T) { args: args{ status: true, currentStatusErr: errors.New("current status error"), - txnOpts: txnOpts, SetCollectionStatus: &Types.Transaction{}, hash: common.BigToHash(big.NewInt(1)), }, @@ -134,7 +125,6 @@ func TestModifyAssetStatus(t *testing.T) { args: args{ status: true, currentStatus: true, - txnOpts: txnOpts, SetCollectionStatus: &Types.Transaction{}, hash: common.BigToHash(big.NewInt(1)), }, @@ -146,7 +136,6 @@ func TestModifyAssetStatus(t *testing.T) { args: args{ status: true, currentStatus: false, - txnOpts: txnOpts, SetAssetStatusErr: errors.New("SetAssetStatus error"), hash: common.BigToHash(big.NewInt(1)), }, @@ -158,7 +147,6 @@ func TestModifyAssetStatus(t *testing.T) { args: args{ status: true, currentStatus: false, - txnOpts: txnOpts, epochErr: errors.New("WaitForAppropriateState error"), SetCollectionStatus: &Types.Transaction{}, SetAssetStatusErr: nil, @@ -173,7 +161,7 @@ func TestModifyAssetStatus(t *testing.T) { SetUpMockInterfaces() cmdUtilsMock.On("CheckCurrentStatus", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("uint16")).Return(tt.args.currentStatus, tt.args.currentStatusErr) - utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(txnOpts) + utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(TxnOpts) cmdUtilsMock.On("WaitForAppropriateState", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("string"), mock.Anything).Return(tt.args.epoch, tt.args.epochErr) assetManagerMock.On("SetCollectionStatus", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.SetCollectionStatus, tt.args.SetAssetStatusErr) transactionMock.On("Hash", mock.Anything).Return(tt.args.hash) @@ -334,7 +322,8 @@ func TestExecuteModifyAssetStatus(t *testing.T) { flagSetMock.On("GetUint16CollectionId", flagSet).Return(tt.args.collectionId, tt.args.collectionIdErr) flagSetMock.On("GetStringStatus", flagSet).Return(tt.args.status, tt.args.statusErr) utilsMock.On("AssignPassword", flagSet).Return(tt.args.password) - utilsMock.On("CheckPassword", mock.Anything, mock.Anything).Return(nil) + utilsMock.On("CheckPassword", mock.Anything).Return(nil) + utilsMock.On("AccountManagerForKeystore").Return(&accounts.AccountManager{}, nil) stringMock.On("ParseBool", mock.AnythingOfType("string")).Return(tt.args.parseStatus, tt.args.parseStatusErr) utilsMock.On("ConnectToClient", mock.AnythingOfType("string")).Return(client) cmdUtilsMock.On("ModifyCollectionStatus", mock.Anything, mock.Anything, mock.Anything).Return(tt.args.ModifyCollectionStatusHash, tt.args.ModifyCollectionStatusErr) diff --git a/cmd/propose.go b/cmd/propose.go index 3a66e4f09..70030c819 100644 --- a/cmd/propose.go +++ b/cmd/propose.go @@ -11,8 +11,11 @@ import ( "razor/pkg/bindings" "razor/utils" "sort" + "strings" + "sync" "time" + Types "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethclient" solsha3 "github.com/miguelmota/go-solidity-sha3" ) @@ -28,8 +31,8 @@ var globalProposedDataStruct types.ProposeFileData // Find iteration using salt as seed //This functions handles the propose state -func (*UtilsStruct) Propose(client *ethclient.Client, config types.Configurations, account types.Account, staker bindings.StructsStaker, epoch uint32, blockNumber *big.Int, rogueData types.Rogue) error { - if state, err := razorUtils.GetBufferedState(client, config.BufferPercent); err != nil || state != 2 { +func (*UtilsStruct) Propose(client *ethclient.Client, config types.Configurations, account types.Account, staker bindings.StructsStaker, epoch uint32, latestHeader *Types.Header, rogueData types.Rogue) error { + if state, err := razorUtils.GetBufferedState(client, latestHeader, config.BufferPercent); err != nil || state != 2 { log.Error("Not propose state") return err } @@ -59,7 +62,7 @@ func (*UtilsStruct) Propose(client *ethclient.Client, config types.Configuration biggestStakerId = smallestStakerId log.Debugf("Propose: In rogue mode, Biggest Stake: %s, Biggest Staker Id: %d", biggestStake, biggestStakerId) } else { - biggestStake, biggestStakerId, biggestStakerErr = cmdUtils.GetBiggestStakeAndId(client, account.Address, epoch) + biggestStake, biggestStakerId, biggestStakerErr = cmdUtils.GetBiggestStakeAndId(client, epoch) if biggestStakerErr != nil { log.Error("Error in calculating biggest staker: ", biggestStakerErr) return biggestStakerErr @@ -131,8 +134,8 @@ func (*UtilsStruct) Propose(client *ethclient.Client, config types.Configuration } log.Info("Current iteration is less than iteration of last proposed block, can propose") } - log.Debugf("Propose: Calling MakeBlock() with arguments blockNumber = %s, epoch = %d, rogueData = %+v", blockNumber, epoch, rogueData) - medians, ids, revealedDataMaps, err := cmdUtils.MakeBlock(client, blockNumber, epoch, rogueData) + log.Debugf("Propose: Calling MakeBlock() with arguments blockNumber = %s, epoch = %d, rogueData = %+v", latestHeader.Number, epoch, rogueData) + medians, ids, revealedDataMaps, err := cmdUtils.MakeBlock(client, latestHeader.Number, epoch, rogueData) if err != nil { log.Error(err) return err @@ -144,14 +147,13 @@ func (*UtilsStruct) Propose(client *ethclient.Client, config types.Configuration txnOpts := razorUtils.GetTxnOpts(types.TransactionOptions{ Client: client, - Password: account.Password, - AccountAddress: account.Address, ChainId: core.ChainId, Config: config, ContractAddress: core.BlockManagerAddress, ABI: bindings.BlockManagerMetaData.ABI, MethodName: "propose", Parameters: []interface{}{epoch, ids, medians, big.NewInt(int64(iteration)), biggestStakerId}, + Account: account, }) log.Debugf("Executing Propose transaction with epoch = %d, Ids = %v, medians = %s, iteration = %s, biggestStakerId = %d", epoch, ids, medians, big.NewInt(int64(iteration)), biggestStakerId) @@ -198,7 +200,7 @@ func (*UtilsStruct) Propose(client *ethclient.Client, config types.Configuration } //This function returns the biggest stake and Id of it -func (*UtilsStruct) GetBiggestStakeAndId(client *ethclient.Client, address string, epoch uint32) (*big.Int, uint32, error) { +func (*UtilsStruct) GetBiggestStakeAndId(client *ethclient.Client, epoch uint32) (*big.Int, uint32, error) { numberOfStakers, err := razorUtils.GetNumberOfStakers(client) if err != nil { return nil, 0, err @@ -210,38 +212,19 @@ func (*UtilsStruct) GetBiggestStakeAndId(client *ethclient.Client, address strin var biggestStakerId uint32 biggestStake := big.NewInt(0) - bufferPercent, err := cmdUtils.GetBufferPercent() + stakeSnapshotArray, err := cmdUtils.BatchGetStakeSnapshotCalls(client, epoch, numberOfStakers) if err != nil { return nil, 0, err } - log.Debug("GetBiggestStakeAndId: Buffer Percent: ", bufferPercent) - - stateRemainingTime, err := razorUtils.GetRemainingTimeOfCurrentState(client, bufferPercent) - if err != nil { - return nil, 0, err - } - log.Debug("GetBiggestStakeAndId: State remaining time: ", stateRemainingTime) - stateTimeout := time.NewTimer(time.Second * time.Duration(stateRemainingTime)) + log.Debugf("Stake Snapshot Array: %+v", stakeSnapshotArray) log.Debug("Iterating over all the stakers...") -loop: - for i := 1; i <= int(numberOfStakers); i++ { - select { - case <-stateTimeout.C: - log.Error("State timeout!") - err = errors.New("state timeout error") - break loop - default: - log.Debug("Propose: Staker Id: ", i) - stake, err := razorUtils.GetStakeSnapshot(client, uint32(i), epoch) - if err != nil { - return nil, 0, err - } - log.Debugf("Stake Snapshot of staker having stakerId %d is %s", i, stake) - if stake.Cmp(biggestStake) > 0 { - biggestStake = stake - biggestStakerId = uint32(i) - } + for i := 0; i < len(stakeSnapshotArray); i++ { + stake := stakeSnapshotArray[i] + log.Debugf("Stake Snapshot of staker having stakerId %d is %s", i+1, stake) + if stake.Cmp(biggestStake) > 0 { + biggestStake = stake + biggestStakerId = uint32(i + 1) } } if err != nil { @@ -252,7 +235,6 @@ loop: return biggestStake, biggestStakerId, nil } -//This function returns the iteration of the proposer if he is elected func (*UtilsStruct) GetIteration(client *ethclient.Client, proposer types.ElectedProposer, bufferPercent int32) int { stake, err := razorUtils.GetStakeSnapshot(client, proposer.StakerId, proposer.Epoch) if err != nil { @@ -265,31 +247,75 @@ func (*UtilsStruct) GetIteration(client *ethclient.Client, proposer types.Electe if err != nil { return -1 } - stateTimeout := time.NewTimer(time.Second * time.Duration(stateRemainingTime)) log.Debug("GetIteration: State remaining time: ", stateRemainingTime) + + stateTimeout := time.NewTimer(time.Second * time.Duration(stateRemainingTime)) + wg := &sync.WaitGroup{} + wg.Add(core.NumRoutines) + done := make(chan bool, 10) + iterationResult := make(chan int, 10) + quit := make(chan bool, 10) + log.Debug("Calculating Iteration...") - log.Debugf("GetIteration: Calling IsElectedProposer() to find iteration...") -loop: - for i := 0; i < 10000000; i++ { - select { - case <-stateTimeout.C: - log.Error("State timeout!") - break loop - default: - proposer.Iteration = i - isElected := cmdUtils.IsElectedProposer(proposer, currentStakerStake) - if isElected { - return i + for routine := 0; routine < core.NumRoutines; routine++ { + go getIterationConcurrently(proposer, currentStakerStake, routine, wg, done, iterationResult, quit, stateTimeout) + } + + log.Debug("Waiting for all the goroutines to finish") + wg.Wait() + log.Debug("Done") + + close(done) + close(quit) + close(iterationResult) + + var iterations []int + + for iteration := range iterationResult { + iterations = append(iterations, iteration) + } + + sort.Ints(iterations) + return iterations[0] +} + +func getIterationConcurrently(proposer types.ElectedProposer, currentStake *big.Int, routine int, wg *sync.WaitGroup, done chan bool, iterationResult chan int, quit chan bool, stateTimeout *time.Timer) { + //PARALLEL IMPLEMENTATION WITH BATCHES + + defer wg.Done() + batchSize := core.BatchSize //1000 + NumBatches := core.MaxIterations / batchSize //10000000/1000 = 10000 + // Batch 0th - [0,1000) + // Batch 1th - [1000,2000) + for batch := 0; batch < NumBatches; batch++ { + for iteration := (batch * batchSize) + routine; iteration < (batch*batchSize)+batchSize; iteration = iteration + core.NumRoutines { + select { + case <-stateTimeout.C: + log.Error("getIterationConcurrently: State timeout!") + iterationResult <- -1 + quit <- true + return + default: + proposer.Iteration = iteration + if len(done) >= 1 || len(quit) >= 1 { + return + } + isElected := cmdUtils.IsElectedProposer(proposer, currentStake) + if isElected { + iterationResult <- iteration + done <- true + return + } } } } - return -1 + iterationResult <- -1 + log.Debug("IsElected is never true for this batch") } //This function returns if the elected staker is proposer or not func (*UtilsStruct) IsElectedProposer(proposer types.ElectedProposer, currentStakerStake *big.Int) bool { seed := solsha3.SoliditySHA3([]string{"uint256"}, []interface{}{big.NewInt(int64(proposer.Iteration))}) - pseudoRandomNumber := pseudoRandomNumberGenerator(seed, proposer.NumberOfStakers, proposer.Salt[:]) //add +1 since prng returns 0 to max-1 and staker start from 1 pseudoRandomNumber = pseudoRandomNumber.Add(pseudoRandomNumber, big.NewInt(1)) @@ -450,6 +476,32 @@ func (*UtilsStruct) GetSmallestStakeAndId(client *ethclient.Client, epoch uint32 return smallestStake, smallestStakerId, nil } +func (*UtilsStruct) BatchGetStakeSnapshotCalls(client *ethclient.Client, epoch uint32, numberOfStakers uint32) ([]*big.Int, error) { + voteManagerABI, err := utils.ABIInterface.Parse(strings.NewReader(bindings.VoteManagerMetaData.ABI)) + if err != nil { + log.Errorf("Error in parsed voteManager ABI: %v", err) + return nil, err + } + + args := make([][]interface{}, numberOfStakers) + for i := uint32(1); i <= numberOfStakers; i++ { + args[i-1] = []interface{}{epoch, i} + } + + results, err := clientUtils.BatchCall(client, &voteManagerABI, core.VoteManagerAddress, core.GetStakeSnapshotMethod, args) + if err != nil { + log.Error("Error in performing getStakeSnapshot batch calls: ", err) + return nil, err + } + + var stakeArray []*big.Int + for _, result := range results { + stakeArray = append(stakeArray, result[0].(*big.Int)) + } + + return stakeArray, nil +} + func updateGlobalProposedDataStruct(proposedData types.ProposeFileData) types.ProposeFileData { globalProposedDataStruct.MediansData = proposedData.MediansData globalProposedDataStruct.RevealedDataMaps = proposedData.RevealedDataMaps diff --git a/cmd/propose_test.go b/cmd/propose_test.go index 8e52c748e..c18379085 100644 --- a/cmd/propose_test.go +++ b/cmd/propose_test.go @@ -1,20 +1,20 @@ package cmd import ( - "crypto/ecdsa" - "crypto/rand" "errors" "fmt" - "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/accounts/abi" + "github.com/stretchr/testify/assert" "math/big" "razor/core/types" "razor/pkg/bindings" + utilsPkgMocks "razor/utils/mocks" "reflect" + "strings" "testing" "github.com/stretchr/testify/mock" - "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" Types "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethclient" @@ -23,21 +23,20 @@ import ( func TestPropose(t *testing.T) { var ( - client *ethclient.Client - account types.Account - config types.Configurations - staker bindings.StructsStaker - epoch uint32 - blockNumber *big.Int + client *ethclient.Client + account types.Account + config types.Configurations + staker bindings.StructsStaker + epoch uint32 ) - privateKey, _ := ecdsa.GenerateKey(crypto.S256(), rand.Reader) - txnOpts, _ := bind.NewKeyedTransactorWithChainID(privateKey, big.NewInt(1)) - salt := []byte{142, 170, 157, 83, 109, 43, 34, 152, 21, 154, 159, 12, 195, 119, 50, 186, 218, 57, 39, 173, 228, 135, 20, 100, 149, 27, 169, 158, 34, 113, 66, 64} saltBytes32 := [32]byte{} copy(saltBytes32[:], salt) + latestHeader := &Types.Header{ + Number: big.NewInt(1001), + } type args struct { rogueData types.Rogue state int64 @@ -75,7 +74,6 @@ func TestPropose(t *testing.T) { fileNameErr error saveDataErr error mediansBigInt []*big.Int - txnOpts *bind.TransactOpts proposeTxn *Types.Transaction proposeErr error hash common.Hash @@ -102,7 +100,6 @@ func TestPropose(t *testing.T) { lastIteration: big.NewInt(5), lastProposedBlockStruct: bindings.StructsBlock{}, medians: []*big.Int{big.NewInt(6701548), big.NewInt(478307)}, - txnOpts: txnOpts, proposeTxn: &Types.Transaction{}, hash: common.BigToHash(big.NewInt(1)), }, @@ -124,7 +121,6 @@ func TestPropose(t *testing.T) { lastIteration: big.NewInt(5), lastProposedBlockStruct: bindings.StructsBlock{}, medians: []*big.Int{big.NewInt(6701548), big.NewInt(478307)}, - txnOpts: txnOpts, proposeTxn: &Types.Transaction{}, hash: common.BigToHash(big.NewInt(1)), }, @@ -146,7 +142,6 @@ func TestPropose(t *testing.T) { lastIteration: big.NewInt(5), lastProposedBlockStruct: bindings.StructsBlock{}, medians: []*big.Int{big.NewInt(6701548), big.NewInt(478307)}, - txnOpts: txnOpts, proposeTxn: &Types.Transaction{}, hash: common.BigToHash(big.NewInt(1)), }, @@ -167,7 +162,6 @@ func TestPropose(t *testing.T) { lastIteration: big.NewInt(5), lastProposedBlockStruct: bindings.StructsBlock{}, medians: []*big.Int{big.NewInt(6701548), big.NewInt(478307)}, - txnOpts: txnOpts, proposeTxn: &Types.Transaction{}, hash: common.BigToHash(big.NewInt(1)), }, @@ -189,7 +183,6 @@ func TestPropose(t *testing.T) { lastIteration: big.NewInt(5), lastProposedBlockStruct: bindings.StructsBlock{}, medians: []*big.Int{big.NewInt(6701548), big.NewInt(478307)}, - txnOpts: txnOpts, proposeTxn: &Types.Transaction{}, hash: common.BigToHash(big.NewInt(1)), }, @@ -211,7 +204,6 @@ func TestPropose(t *testing.T) { lastIteration: big.NewInt(5), lastProposedBlockStruct: bindings.StructsBlock{}, medians: []*big.Int{big.NewInt(6701548), big.NewInt(478307)}, - txnOpts: txnOpts, proposeTxn: &Types.Transaction{}, hash: common.BigToHash(big.NewInt(1)), }, @@ -233,7 +225,6 @@ func TestPropose(t *testing.T) { lastIteration: big.NewInt(5), lastProposedBlockStruct: bindings.StructsBlock{}, medians: []*big.Int{big.NewInt(6701548), big.NewInt(478307)}, - txnOpts: txnOpts, proposeTxn: &Types.Transaction{}, hash: common.BigToHash(big.NewInt(1)), }, @@ -255,7 +246,6 @@ func TestPropose(t *testing.T) { lastIteration: big.NewInt(5), lastProposedBlockStruct: bindings.StructsBlock{}, medians: []*big.Int{big.NewInt(6701548), big.NewInt(478307)}, - txnOpts: txnOpts, proposeTxn: &Types.Transaction{}, hash: common.BigToHash(big.NewInt(1)), }, @@ -277,7 +267,6 @@ func TestPropose(t *testing.T) { lastIteration: big.NewInt(5), lastProposedBlockStructErr: errors.New("lastProposedBlockStruct error"), medians: []*big.Int{big.NewInt(6701548), big.NewInt(478307)}, - txnOpts: txnOpts, proposeTxn: &Types.Transaction{}, hash: common.BigToHash(big.NewInt(1)), }, @@ -301,7 +290,6 @@ func TestPropose(t *testing.T) { Iteration: big.NewInt(1), }, medians: []*big.Int{big.NewInt(6701548), big.NewInt(478307)}, - txnOpts: txnOpts, proposeTxn: &Types.Transaction{}, hash: common.BigToHash(big.NewInt(1)), }, @@ -325,7 +313,6 @@ func TestPropose(t *testing.T) { Iteration: big.NewInt(2), }, medians: []*big.Int{big.NewInt(6701548), big.NewInt(478307)}, - txnOpts: txnOpts, proposeTxn: &Types.Transaction{}, hash: common.BigToHash(big.NewInt(1)), }, @@ -350,7 +337,6 @@ func TestPropose(t *testing.T) { Iteration: big.NewInt(2), }, medians: []*big.Int{big.NewInt(6701548), big.NewInt(478307)}, - txnOpts: txnOpts, proposeTxn: &Types.Transaction{}, hash: common.BigToHash(big.NewInt(1)), }, @@ -372,7 +358,6 @@ func TestPropose(t *testing.T) { lastIteration: big.NewInt(5), lastProposedBlockStruct: bindings.StructsBlock{}, mediansErr: errors.New("makeBlock error"), - txnOpts: txnOpts, proposeTxn: &Types.Transaction{}, hash: common.BigToHash(big.NewInt(1)), }, @@ -394,7 +379,6 @@ func TestPropose(t *testing.T) { lastIteration: big.NewInt(5), lastProposedBlockStruct: bindings.StructsBlock{}, medians: []*big.Int{big.NewInt(6701548), big.NewInt(478307)}, - txnOpts: txnOpts, proposeErr: errors.New("propose error"), hash: common.BigToHash(big.NewInt(1)), }, @@ -416,7 +400,6 @@ func TestPropose(t *testing.T) { lastIteration: big.NewInt(5), lastProposedBlockStruct: bindings.StructsBlock{}, medians: []*big.Int{big.NewInt(6701548), big.NewInt(478307)}, - txnOpts: txnOpts, proposeTxn: &Types.Transaction{}, hash: common.BigToHash(big.NewInt(1)), fileNameErr: errors.New("fileName error"), @@ -439,7 +422,6 @@ func TestPropose(t *testing.T) { lastIteration: big.NewInt(5), lastProposedBlockStruct: bindings.StructsBlock{}, medians: []*big.Int{big.NewInt(6701548), big.NewInt(478307)}, - txnOpts: txnOpts, proposeTxn: &Types.Transaction{}, hash: common.BigToHash(big.NewInt(1)), saveDataErr: errors.New("error in saving data"), @@ -480,7 +462,6 @@ func TestPropose(t *testing.T) { lastIteration: big.NewInt(5), lastProposedBlockStruct: bindings.StructsBlock{}, medians: []*big.Int{big.NewInt(6701548), big.NewInt(478307)}, - txnOpts: txnOpts, proposeTxn: &Types.Transaction{}, hash: common.BigToHash(big.NewInt(1)), }, @@ -507,7 +488,6 @@ func TestPropose(t *testing.T) { lastIteration: big.NewInt(5), lastProposedBlockStruct: bindings.StructsBlock{}, medians: []*big.Int{big.NewInt(6701548), big.NewInt(478307)}, - txnOpts: txnOpts, proposeTxn: &Types.Transaction{}, hash: common.BigToHash(big.NewInt(1)), }, @@ -529,7 +509,6 @@ func TestPropose(t *testing.T) { lastIteration: big.NewInt(5), lastProposedBlockStruct: bindings.StructsBlock{}, medians: []*big.Int{big.NewInt(6701548), big.NewInt(478307)}, - txnOpts: txnOpts, proposeTxn: &Types.Transaction{}, hash: common.BigToHash(big.NewInt(1)), waitForBlockCompletionErr: errors.New("waitForBlockCompletion error"), @@ -540,9 +519,9 @@ func TestPropose(t *testing.T) { for _, tt := range tests { SetUpMockInterfaces() - utilsMock.On("GetBufferedState", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("int32")).Return(tt.args.state, tt.args.stateErr) + utilsMock.On("GetBufferedState", mock.Anything, mock.Anything, mock.Anything).Return(tt.args.state, tt.args.stateErr) utilsMock.On("GetNumberOfStakers", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("string")).Return(tt.args.numStakers, tt.args.numStakerErr) - cmdUtilsMock.On("GetBiggestStakeAndId", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("string"), mock.AnythingOfType("uint32")).Return(tt.args.biggestStake, tt.args.biggestStakerId, tt.args.biggestStakerIdErr) + cmdUtilsMock.On("GetBiggestStakeAndId", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("uint32")).Return(tt.args.biggestStake, tt.args.biggestStakerId, tt.args.biggestStakerIdErr) cmdUtilsMock.On("GetSmallestStakeAndId", mock.Anything, mock.Anything).Return(tt.args.smallestStake, tt.args.smallestStakerId, tt.args.smallestStakerIdErr) utilsMock.On("GetRandaoHash", mock.AnythingOfType("*ethclient.Client")).Return(tt.args.randaoHash, tt.args.randaoHashErr) cmdUtilsMock.On("GetIteration", mock.Anything, mock.Anything, mock.Anything).Return(tt.args.iteration) @@ -557,7 +536,7 @@ func TestPropose(t *testing.T) { utilsMock.On("ConvertUint32ArrayToBigIntArray", mock.Anything).Return(tt.args.mediansBigInt) pathMock.On("GetProposeDataFileName", mock.AnythingOfType("string")).Return(tt.args.fileName, tt.args.fileNameErr) fileUtilsMock.On("SaveDataToProposeJsonFile", mock.Anything, mock.Anything, mock.Anything).Return(tt.args.saveDataErr) - utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(txnOpts) + utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(TxnOpts) blockManagerMock.On("Propose", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.proposeTxn, tt.args.proposeErr) transactionMock.On("Hash", mock.Anything).Return(tt.args.hash) cmdUtilsMock.On("GetBufferPercent").Return(tt.args.bufferPercent, tt.args.bufferPercentErr) @@ -565,7 +544,7 @@ func TestPropose(t *testing.T) { utils := &UtilsStruct{} t.Run(tt.name, func(t *testing.T) { - err := utils.Propose(client, config, account, staker, epoch, blockNumber, tt.args.rogueData) + err := utils.Propose(client, config, account, staker, epoch, latestHeader, tt.args.rogueData) if err == nil || tt.wantErr == nil { if err != tt.wantErr { t.Errorf("Error for Propose function, got = %v, want %v", err, tt.wantErr) @@ -581,18 +560,13 @@ func TestPropose(t *testing.T) { func TestGetBiggestStakeAndId(t *testing.T) { var client *ethclient.Client - var address string var epoch uint32 type args struct { - numOfStakers uint32 - numOfStakersErr error - bufferPercent int32 - bufferPercentErr error - remainingTime int64 - remainingTimeErr error - stake *big.Int - stakeErr error + numOfStakers uint32 + numOfStakersErr error + stakeArray []*big.Int + stakeErr error } tests := []struct { name string @@ -604,12 +578,11 @@ func TestGetBiggestStakeAndId(t *testing.T) { { name: "Test 1: When GetBiggestStakeAndId function executes successfully", args: args{ - numOfStakers: 2, - remainingTime: 10, - stake: big.NewInt(1).Mul(big.NewInt(5326), big.NewInt(1e18)), + numOfStakers: 7, + stakeArray: []*big.Int{big.NewInt(89999), big.NewInt(70000), big.NewInt(72000), big.NewInt(99999), big.NewInt(200030), big.NewInt(67777), big.NewInt(100011)}, }, - wantStake: big.NewInt(1).Mul(big.NewInt(5326), big.NewInt(1e18)), - wantId: 1, + wantStake: big.NewInt(200030), + wantId: 5, wantErr: nil, }, { @@ -625,55 +598,30 @@ func TestGetBiggestStakeAndId(t *testing.T) { name: "Test 3: When there is an error in getting numOfStakers", args: args{ numOfStakersErr: errors.New("numOfStakers error"), - remainingTime: 10, }, wantStake: nil, wantId: 0, wantErr: errors.New("numOfStakers error"), }, { - name: "Test 4: When there is an error in getting stake", - args: args{ - numOfStakers: 5, - remainingTime: 10, - stakeErr: errors.New("stake error"), - }, - wantStake: nil, - wantId: 0, - wantErr: errors.New("stake error"), - }, - { - name: "Test 5: When there is an error in getting remaining time", - args: args{ - numOfStakers: 2, - remainingTime: 10, - remainingTimeErr: errors.New("time error"), - }, - wantStake: nil, - wantId: 0, - wantErr: errors.New("time error"), - }, - { - name: "Test 6: When there is a timeout case", + name: "Test 4: When there is an error in getting stakeArray from batch calls", args: args{ - numOfStakers: 100000, - bufferPercent: 10, - remainingTime: 0, - stake: big.NewInt(1).Mul(big.NewInt(5326), big.NewInt(1e18)), + numOfStakers: 5, + stakeErr: errors.New("batch calls error"), }, wantStake: nil, wantId: 0, - wantErr: errors.New("state timeout error"), + wantErr: errors.New("batch calls error"), }, { - name: "Test 7: When there is an error in getting buffer percent", + name: "Test 5: When there are large number of stakers", args: args{ - numOfStakers: 2, - bufferPercentErr: errors.New("buffer error"), + numOfStakers: 999, + stakeArray: GenerateDummyStakeSnapshotArray(999), }, - wantStake: nil, - wantId: 0, - wantErr: errors.New("buffer error"), + wantStake: big.NewInt(999000), + wantId: 999, + wantErr: nil, }, } for _, tt := range tests { @@ -681,19 +629,11 @@ func TestGetBiggestStakeAndId(t *testing.T) { SetUpMockInterfaces() utilsMock.On("GetNumberOfStakers", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("string")).Return(tt.args.numOfStakers, tt.args.numOfStakersErr) - utilsMock.On("GetStakeSnapshot", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("uint32"), mock.AnythingOfType("uint32")).Return(tt.args.stake, tt.args.stakeErr) - utilsMock.On("GetRemainingTimeOfCurrentState", mock.Anything, mock.Anything).Return(tt.args.remainingTime, tt.args.remainingTimeErr) - cmdUtilsMock.On("GetBufferPercent").Return(tt.args.bufferPercent, tt.args.bufferPercentErr) + cmdUtilsMock.On("BatchGetStakeSnapshotCalls", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("uint32"), mock.AnythingOfType("uint32")).Return(tt.args.stakeArray, tt.args.stakeErr) utils := &UtilsStruct{} - gotStake, gotId, err := utils.GetBiggestStakeAndId(client, address, epoch) - if gotStake.Cmp(tt.wantStake) != 0 { - t.Errorf("Biggest Stake from GetBiggestStakeAndId function, got = %v, want %v", gotStake, tt.wantStake) - } - if gotId != tt.wantId { - t.Errorf("Staker Id of staker having biggest Influence from GetBiggestStakeAndId function, got = %v, want %v", gotId, tt.wantId) - } + gotStake, gotId, err := utils.GetBiggestStakeAndId(client, epoch) if err == nil || tt.wantErr == nil { if err != tt.wantErr { t.Errorf("Error for GetBiggestStakeAndId function, got = %v, want %v", err, tt.wantErr) @@ -703,6 +643,12 @@ func TestGetBiggestStakeAndId(t *testing.T) { t.Errorf("Error for GetBiggestStakeAndId function, got = %v, want %v", err, tt.wantErr) } } + if gotStake.Cmp(tt.wantStake) != 0 { + t.Errorf("Biggest Stake from GetBiggestStakeAndId function, got = %v, want %v", gotStake, tt.wantStake) + } + if gotId != tt.wantId { + t.Errorf("Staker Id of staker having biggest Influence from GetBiggestStakeAndId function, got = %v, want %v", gotId, tt.wantId) + } }) } @@ -715,15 +661,24 @@ func stakeSnapshotValue(stake string) *big.Int { func TestGetIteration(t *testing.T) { var client *ethclient.Client - var proposer types.ElectedProposer var bufferPercent int32 + salt := []byte{142, 170, 157, 83, 109, 43, 34, 152, 21, 154, 159, 12, 195, 119, 50, 186, 218, 57, 39, 173, 228, 135, 20, 100, 149, 27, 169, 158, 34, 113, 66, 64} + saltBytes32 := [32]byte{} + copy(saltBytes32[:], salt) + + proposer := types.ElectedProposer{ + BiggestStake: big.NewInt(1).Mul(big.NewInt(10000000), big.NewInt(1e18)), + StakerId: 2, + NumberOfStakers: 10, + Salt: saltBytes32, + } + type args struct { - stakeSnapshot *big.Int - stakeSnapshotErr error - isElectedProposer bool - remainingTime int64 - remainingTimeErr error + stakeSnapshot *big.Int + stakeSnapshotErr error + remainingTime int64 + remainingTimeErr error } tests := []struct { name string @@ -733,15 +688,15 @@ func TestGetIteration(t *testing.T) { { name: "Test 1: When getIteration returns a valid iteration", args: args{ - stakeSnapshot: stakeSnapshotValue("2592145500000000000000000"), - isElectedProposer: true, - remainingTime: 100, + stakeSnapshot: big.NewInt(1000), + remainingTime: 10, }, - want: 0, + want: 70183, }, { name: "Test 2: When there is an error in getting stakeSnapshotValue", args: args{ + stakeSnapshot: big.NewInt(0), stakeSnapshotErr: errors.New("error in getting stakeSnapshotValue"), }, want: -1, @@ -749,33 +704,32 @@ func TestGetIteration(t *testing.T) { { name: "Test 3: When getIteration returns an invalid iteration", args: args{ - stakeSnapshot: stakeSnapshotValue("2592145500000000000000000"), - isElectedProposer: false, - remainingTime: 2, + stakeSnapshot: big.NewInt(1), + remainingTime: 2, }, want: -1, }, { name: "Test 4: When there is an error in getting remaining time for the state", args: args{ - stakeSnapshot: stakeSnapshotValue("2592145500000000000000000"), - isElectedProposer: true, - remainingTimeErr: errors.New("remaining time error"), + stakeSnapshot: stakeSnapshotValue("2592145500000000000000000"), + remainingTimeErr: errors.New("remaining time error"), }, want: -1, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - SetUpMockInterfaces() - utilsMock.On("GetStakeSnapshot", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("uint32"), mock.AnythingOfType("uint32")).Return(tt.args.stakeSnapshot, tt.args.stakeSnapshotErr) - cmdUtilsMock.On("IsElectedProposer", mock.Anything, mock.Anything).Return(tt.args.isElectedProposer) - utilsMock.On("GetRemainingTimeOfCurrentState", mock.Anything, mock.Anything).Return(tt.args.remainingTime, tt.args.remainingTimeErr) + utilsMock = new(utilsPkgMocks.Utils) + razorUtils = utilsMock - utils := &UtilsStruct{} + cmdUtils = &UtilsStruct{} - if got := utils.GetIteration(client, proposer, bufferPercent); got != tt.want { + utilsMock.On("GetStakeSnapshot", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("uint32"), mock.AnythingOfType("uint32")).Return(big.NewInt(1).Mul(tt.args.stakeSnapshot, big.NewInt(1e18)), tt.args.stakeSnapshotErr) + utilsMock.On("GetRemainingTimeOfCurrentState", mock.Anything, mock.Anything).Return(tt.args.remainingTime, tt.args.remainingTimeErr) + + if got := cmdUtils.GetIteration(client, proposer, bufferPercent); got != tt.want { t.Errorf("getIteration() = %v, want %v", got, tt.want) } }) @@ -1117,17 +1071,145 @@ func TestGetSortedRevealedValues(t *testing.T) { { name: "Test 1: When GetSortedRevealedValues executes successfully", args: args{ - assignedAssets: []types.RevealedStruct{{RevealedValues: []types.AssignedAsset{{LeafId: 1, Value: big.NewInt(100)}}, Influence: big.NewInt(100)}}, + assignedAssets: []types.RevealedStruct{ + { + RevealedValues: []types.AssignedAsset{ + {LeafId: 3, Value: big.NewInt(601)}, + {LeafId: 6, Value: big.NewInt(750)}, + {LeafId: 1, Value: big.NewInt(400)}, + }, + Influence: big.NewInt(10000000), + }, + { + RevealedValues: []types.AssignedAsset{ + {LeafId: 10, Value: big.NewInt(1100)}, + {LeafId: 5, Value: big.NewInt(900)}, + {LeafId: 7, Value: big.NewInt(302)}, + }, + Influence: big.NewInt(20000000), + }, + { + RevealedValues: []types.AssignedAsset{ + {LeafId: 3, Value: big.NewInt(600)}, + {LeafId: 7, Value: big.NewInt(300)}, + {LeafId: 9, Value: big.NewInt(1600)}, + }, + Influence: big.NewInt(30000000), + }, + { + RevealedValues: []types.AssignedAsset{ + {LeafId: 10, Value: big.NewInt(1105)}, + {LeafId: 8, Value: big.NewInt(950)}, + {LeafId: 7, Value: big.NewInt(300)}, + }, + Influence: big.NewInt(40000000), + }, + }, }, want: &types.RevealedDataMaps{ - SortedRevealedValues: map[uint16][]*big.Int{1: {big.NewInt(100)}}, - VoteWeights: map[string]*big.Int{big.NewInt(100).String(): big.NewInt(100)}, - InfluenceSum: map[uint16]*big.Int{1: big.NewInt(100)}, + SortedRevealedValues: map[uint16][]*big.Int{ + 1: {big.NewInt(400)}, + 3: {big.NewInt(600), big.NewInt(601)}, + 5: {big.NewInt(900)}, + 6: {big.NewInt(750)}, + 7: {big.NewInt(300), big.NewInt(302)}, + 8: {big.NewInt(950)}, + 9: {big.NewInt(1600)}, + 10: {big.NewInt(1100), big.NewInt(1105)}, + }, + VoteWeights: map[string]*big.Int{ + "1600": big.NewInt(30000000), + "300": big.NewInt(70000000), + "302": big.NewInt(20000000), + "400": big.NewInt(10000000), + "600": big.NewInt(30000000), + "601": big.NewInt(10000000), + "750": big.NewInt(10000000), + "1100": big.NewInt(20000000), + "1105": big.NewInt(40000000), + "950": big.NewInt(40000000), + "900": big.NewInt(20000000), + }, + InfluenceSum: map[uint16]*big.Int{ + 1: big.NewInt(10000000), + 3: big.NewInt(40000000), + 5: big.NewInt(20000000), + 6: big.NewInt(10000000), + 7: big.NewInt(90000000), + 8: big.NewInt(40000000), + 9: big.NewInt(30000000), + 10: big.NewInt(60000000), + }, }, wantErr: false, }, { - name: "Test 2: When there is an error in getting assignedAssets", + name: "Test 2: When there are multiple equal and unequal vote values for single leafId", + args: args{ + assignedAssets: []types.RevealedStruct{ + { + RevealedValues: []types.AssignedAsset{ + {LeafId: 1, Value: big.NewInt(600)}, + {LeafId: 2, Value: big.NewInt(750)}, + {LeafId: 3, Value: big.NewInt(400)}, + }, + Influence: big.NewInt(10000000), + }, + { + RevealedValues: []types.AssignedAsset{ + {LeafId: 1, Value: big.NewInt(601)}, + {LeafId: 2, Value: big.NewInt(752)}, + }, + Influence: big.NewInt(20000000), + }, + { + RevealedValues: []types.AssignedAsset{ + {LeafId: 1, Value: big.NewInt(601)}, + {LeafId: 2, Value: big.NewInt(756)}, + {LeafId: 4, Value: big.NewInt(1600)}, + }, + Influence: big.NewInt(30000000), + }, + }, + }, + want: &types.RevealedDataMaps{ + SortedRevealedValues: map[uint16][]*big.Int{ + 1: {big.NewInt(600), big.NewInt(601)}, + 2: {big.NewInt(750), big.NewInt(752), big.NewInt(756)}, + 3: {big.NewInt(400)}, + 4: {big.NewInt(1600)}, + }, + VoteWeights: map[string]*big.Int{ + "1600": big.NewInt(30000000), + "400": big.NewInt(10000000), + "600": big.NewInt(10000000), + "601": big.NewInt(50000000), + "750": big.NewInt(10000000), + "752": big.NewInt(20000000), + "756": big.NewInt(30000000), + }, + InfluenceSum: map[uint16]*big.Int{ + 1: big.NewInt(60000000), + 2: big.NewInt(60000000), + 3: big.NewInt(10000000), + 4: big.NewInt(30000000), + }, + }, + }, + { + name: "Test 3: When assignedAssets is empty", + args: args{ + assignedAssets: []types.RevealedStruct{}, + }, + want: &types.RevealedDataMaps{ + SortedRevealedValues: map[uint16][]*big.Int{}, + VoteWeights: map[string]*big.Int{}, + InfluenceSum: map[uint16]*big.Int{}, + }, + wantErr: false, + }, + { + name: "Test 4: When there is an error in getting assignedAssets", args: args{ assignedAssetsErr: errors.New("error in getting assets"), }, @@ -1240,6 +1322,72 @@ func TestGetSmallestStakeAndId(t *testing.T) { } } +func TestBatchGetStakeCalls(t *testing.T) { + var client *ethclient.Client + var epoch uint32 + + voteManagerABI, _ := abi.JSON(strings.NewReader(bindings.VoteManagerMetaData.ABI)) + + type args struct { + ABI abi.ABI + numberOfStakers uint32 + parseErr error + batchCallResults [][]interface{} + batchCallError error + } + tests := []struct { + name string + args args + wantStakes []*big.Int + wantErr error + }{ + { + name: "Test 1: When BatchGetStakeCalls executes successfully", + args: args{ + ABI: voteManagerABI, + numberOfStakers: 3, + batchCallResults: [][]interface{}{ + {big.NewInt(10)}, + {big.NewInt(11)}, + {big.NewInt(12)}, + }, + }, + wantStakes: []*big.Int{ + big.NewInt(10), + big.NewInt(11), + big.NewInt(12), + }, + wantErr: nil, + }, + { + name: "Test 2: When there is an error in parsing voteManager ABI", + args: args{ + parseErr: errors.New("parse error"), + }, + wantErr: errors.New("parse error"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + SetUpMockInterfaces() + + abiUtilsMock.On("Parse", mock.Anything).Return(tt.args.ABI, tt.args.parseErr) + clientUtilsMock.On("BatchCall", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.batchCallResults, tt.args.batchCallError) + + ut := &UtilsStruct{} + gotStakes, err := ut.BatchGetStakeSnapshotCalls(client, epoch, tt.args.numberOfStakers) + + if err == nil || tt.wantErr == nil { + assert.Equal(t, tt.wantErr, err) + } else { + assert.EqualError(t, err, tt.wantErr.Error()) + } + + assert.Equal(t, tt.wantStakes, gotStakes) + }) + } +} + func BenchmarkGetIteration(b *testing.B) { var client *ethclient.Client var bufferPercent int32 @@ -1283,7 +1431,6 @@ func BenchmarkGetIteration(b *testing.B) { func BenchmarkGetBiggestStakeAndId(b *testing.B) { var client *ethclient.Client - var address string var epoch uint32 var table = []struct { @@ -1301,12 +1448,10 @@ func BenchmarkGetBiggestStakeAndId(b *testing.B) { SetUpMockInterfaces() utilsMock.On("GetNumberOfStakers", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("string")).Return(v.numOfStakers, nil) - utilsMock.On("GetStakeSnapshot", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("uint32"), mock.AnythingOfType("uint32")).Return(big.NewInt(10000), nil) - utilsMock.On("GetRemainingTimeOfCurrentState", mock.Anything, mock.Anything).Return(int64(150), nil) - cmdUtilsMock.On("GetBufferPercent").Return(int32(60), nil) + cmdUtilsMock.On("BatchGetStakeSnapshotCalls", mock.Anything, mock.Anything, mock.Anything).Return(GenerateDummyStakeSnapshotArray(v.numOfStakers), nil) ut := &UtilsStruct{} - _, _, err := ut.GetBiggestStakeAndId(client, address, epoch) + _, _, err := ut.GetBiggestStakeAndId(client, epoch) if err != nil { log.Fatal(err) } @@ -1387,6 +1532,15 @@ func BenchmarkMakeBlock(b *testing.B) { } } +func GenerateDummyStakeSnapshotArray(numOfStakers uint32) []*big.Int { + stakeSnapshotArray := make([]*big.Int, numOfStakers) + for i := 0; i < int(numOfStakers); i++ { + // For testing purposes, we will assign a stake value of (i + 1) * 1000 + stakeSnapshotArray[i] = big.NewInt(int64(i+1) * 1000) + } + return stakeSnapshotArray +} + func GetDummyVotes(numOfVotes int) []*big.Int { var result []*big.Int for i := 0; i < numOfVotes; i++ { diff --git a/cmd/resetUnstakeLock.go b/cmd/resetUnstakeLock.go index 477e86d7b..3f4d5d474 100644 --- a/cmd/resetUnstakeLock.go +++ b/cmd/resetUnstakeLock.go @@ -2,6 +2,7 @@ package cmd import ( + "razor/accounts" "razor/core" "razor/core/types" "razor/logger" @@ -50,18 +51,22 @@ func (*UtilsStruct) ExecuteExtendLock(flagSet *pflag.FlagSet) { log.Debug("Getting password...") password := razorUtils.AssignPassword(flagSet) - err = razorUtils.CheckPassword(address, password) + accountManager, err := razorUtils.AccountManagerForKeystore() + utils.CheckError("Error in getting accounts manager for keystore: ", err) + + account := accounts.InitAccountStruct(address, password, accountManager) + + err = razorUtils.CheckPassword(account) utils.CheckError("Error in fetching private key from given password: ", err) stakerId, err := razorUtils.AssignStakerId(flagSet, client, address) utils.CheckError("Error in getting stakerId: ", err) extendLockInput := types.ExtendLockInput{ - Address: address, - Password: password, StakerId: stakerId, + Account: account, } - log.Debugf("ExecuteExtendLock: Calling ResetUnstakeLock with arguments extendLockInput = %+v", extendLockInput) + txn, err := cmdUtils.ResetUnstakeLock(client, config, extendLockInput) utils.CheckError("Error in extending lock: ", err) err = razorUtils.WaitForBlockCompletion(client, txn.Hex()) @@ -72,14 +77,13 @@ func (*UtilsStruct) ExecuteExtendLock(flagSet *pflag.FlagSet) { func (*UtilsStruct) ResetUnstakeLock(client *ethclient.Client, config types.Configurations, extendLockInput types.ExtendLockInput) (common.Hash, error) { txnOpts := razorUtils.GetTxnOpts(types.TransactionOptions{ Client: client, - Password: extendLockInput.Password, - AccountAddress: extendLockInput.Address, ChainId: core.ChainId, Config: config, ContractAddress: core.StakeManagerAddress, MethodName: "resetUnstakeLock", Parameters: []interface{}{extendLockInput.StakerId}, ABI: bindings.StakeManagerMetaData.ABI, + Account: extendLockInput.Account, }) log.Info("Extending lock...") diff --git a/cmd/resetUnstakeLock_test.go b/cmd/resetUnstakeLock_test.go index 2755180f3..1d598abdc 100644 --- a/cmd/resetUnstakeLock_test.go +++ b/cmd/resetUnstakeLock_test.go @@ -1,16 +1,13 @@ package cmd import ( - "crypto/ecdsa" - "crypto/rand" "errors" - "github.com/ethereum/go-ethereum/crypto" "math/big" + "razor/accounts" "razor/core" "razor/core/types" "testing" - "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" Types "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethclient" @@ -19,16 +16,11 @@ import ( ) func TestExtendLock(t *testing.T) { - - privateKey, _ := ecdsa.GenerateKey(crypto.S256(), rand.Reader) - txnOpts, _ := bind.NewKeyedTransactorWithChainID(privateKey, big.NewInt(31337)) - var extendLockInput types.ExtendLockInput var config types.Configurations var client *ethclient.Client type args struct { - txnOpts *bind.TransactOpts resetLockTxn *Types.Transaction resetLockErr error hash common.Hash @@ -42,7 +34,6 @@ func TestExtendLock(t *testing.T) { { name: "Test 1: When resetLock function executes successfully", args: args{ - txnOpts: txnOpts, resetLockTxn: &Types.Transaction{}, resetLockErr: nil, hash: common.BigToHash(big.NewInt(1)), @@ -53,7 +44,6 @@ func TestExtendLock(t *testing.T) { { name: "Test 2: When ResetLock transaction fails", args: args{ - txnOpts: txnOpts, resetLockTxn: &Types.Transaction{}, resetLockErr: errors.New("resetLock error"), hash: common.BigToHash(big.NewInt(1)), @@ -66,7 +56,7 @@ func TestExtendLock(t *testing.T) { t.Run(tt.name, func(t *testing.T) { SetUpMockInterfaces() - utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(txnOpts) + utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(TxnOpts) stakeManagerMock.On("ResetUnstakeLock", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("*bind.TransactOpts"), mock.AnythingOfType("uint32")).Return(tt.args.resetLockTxn, tt.args.resetLockErr) transactionMock.On("Hash", mock.Anything).Return(tt.args.hash) @@ -183,7 +173,8 @@ func TestExecuteExtendLock(t *testing.T) { fileUtilsMock.On("AssignLogFile", mock.AnythingOfType("*pflag.FlagSet"), mock.Anything) cmdUtilsMock.On("GetConfigData").Return(tt.args.config, tt.args.configErr) utilsMock.On("AssignPassword", flagSet).Return(tt.args.password) - utilsMock.On("CheckPassword", mock.Anything, mock.Anything).Return(nil) + utilsMock.On("CheckPassword", mock.Anything).Return(nil) + utilsMock.On("AccountManagerForKeystore").Return(&accounts.AccountManager{}, nil) flagSetMock.On("GetStringAddress", mock.AnythingOfType("*pflag.FlagSet")).Return(tt.args.address, tt.args.addressErr) utilsMock.On("AssignStakerId", flagSet, mock.AnythingOfType("*ethclient.Client"), mock.Anything).Return(tt.args.stakerId, tt.args.stakerIdErr) utilsMock.On("ConnectToClient", mock.AnythingOfType("string")).Return(client) diff --git a/cmd/reveal.go b/cmd/reveal.go index c0d1be8a0..0667c6666 100644 --- a/cmd/reveal.go +++ b/cmd/reveal.go @@ -12,6 +12,7 @@ import ( "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/common" + Types "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethclient" ) @@ -29,8 +30,8 @@ func (*UtilsStruct) CheckForLastCommitted(client *ethclient.Client, staker bindi } //This function checks if the state is reveal or not and then reveals the votes -func (*UtilsStruct) Reveal(client *ethclient.Client, config types.Configurations, account types.Account, epoch uint32, commitData types.CommitData, signature []byte) (common.Hash, error) { - if state, err := razorUtils.GetBufferedState(client, config.BufferPercent); err != nil || state != 1 { +func (*UtilsStruct) Reveal(client *ethclient.Client, config types.Configurations, account types.Account, epoch uint32, latestHeader *Types.Header, commitData types.CommitData, signature []byte) (common.Hash, error) { + if state, err := razorUtils.GetBufferedState(client, latestHeader, config.BufferPercent); err != nil || state != 1 { log.Error("Not reveal state") return core.NilHash, err } @@ -56,14 +57,13 @@ func (*UtilsStruct) Reveal(client *ethclient.Client, config types.Configurations txnOpts := razorUtils.GetTxnOpts(types.TransactionOptions{ Client: client, - Password: account.Password, - AccountAddress: account.Address, ChainId: core.ChainId, Config: config, ContractAddress: core.VoteManagerAddress, ABI: bindings.VoteManagerMetaData.ABI, MethodName: "reveal", Parameters: []interface{}{epoch, treeRevealData, signature}, + Account: account, }) log.Debugf("Executing Reveal transaction wih epoch = %d, treeRevealData = %v, signature = %v", epoch, treeRevealData, signature) txn, err := voteManagerUtils.Reveal(client, txnOpts, epoch, treeRevealData, signature) diff --git a/cmd/reveal_test.go b/cmd/reveal_test.go index cc579c4da..c276ade07 100644 --- a/cmd/reveal_test.go +++ b/cmd/reveal_test.go @@ -1,11 +1,8 @@ package cmd import ( - "crypto/ecdsa" - "crypto/rand" "errors" "fmt" - "github.com/ethereum/go-ethereum/crypto" "math/big" "razor/core" "razor/core/types" @@ -14,7 +11,6 @@ import ( "testing" "github.com/ethereum/go-ethereum/accounts/abi" - "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" Types "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethclient" @@ -89,15 +85,15 @@ func TestCheckForLastCommitted(t *testing.T) { } func TestReveal(t *testing.T) { - var client *ethclient.Client - var commitData types.CommitData - var signature []byte - var account types.Account - var config types.Configurations - var epoch uint32 - - privateKey, _ := ecdsa.GenerateKey(crypto.S256(), rand.Reader) - txnOpts, _ := bind.NewKeyedTransactorWithChainID(privateKey, big.NewInt(1)) + var ( + client *ethclient.Client + commitData types.CommitData + signature []byte + account types.Account + config types.Configurations + epoch uint32 + latestHeader *Types.Header + ) type args struct { state int64 @@ -105,7 +101,6 @@ func TestReveal(t *testing.T) { merkleTree [][][]byte merkleTreeErr error treeRevealData bindings.StructsMerkleTree - txnOpts *bind.TransactOpts revealTxn *Types.Transaction revealErr error hash common.Hash @@ -121,7 +116,6 @@ func TestReveal(t *testing.T) { args: args{ state: 1, stateErr: nil, - txnOpts: txnOpts, revealTxn: &Types.Transaction{}, revealErr: nil, hash: common.BigToHash(big.NewInt(1)), @@ -133,7 +127,6 @@ func TestReveal(t *testing.T) { name: "Test 2: When there is an error in getting state", args: args{ stateErr: errors.New("state error"), - txnOpts: txnOpts, revealTxn: &Types.Transaction{}, revealErr: nil, hash: common.BigToHash(big.NewInt(1)), @@ -146,7 +139,6 @@ func TestReveal(t *testing.T) { args: args{ state: 1, stateErr: nil, - txnOpts: txnOpts, revealTxn: &Types.Transaction{}, revealErr: errors.New("reveal error"), hash: common.BigToHash(big.NewInt(1)), @@ -168,16 +160,16 @@ func TestReveal(t *testing.T) { t.Run(tt.name, func(t *testing.T) { SetUpMockInterfaces() - utilsMock.On("GetBufferedState", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("int32")).Return(tt.args.state, tt.args.stateErr) + utilsMock.On("GetBufferedState", mock.Anything, mock.Anything, mock.Anything).Return(tt.args.state, tt.args.stateErr) merkleUtilsMock.On("CreateMerkle", mock.Anything).Return(tt.args.merkleTree, tt.args.merkleTreeErr) cmdUtilsMock.On("GenerateTreeRevealData", mock.Anything, mock.Anything).Return(tt.args.treeRevealData) - utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(tt.args.txnOpts) + utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(TxnOpts) voteManagerMock.On("Reveal", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("*bind.TransactOpts"), mock.AnythingOfType("uint32"), mock.Anything, mock.Anything).Return(tt.args.revealTxn, tt.args.revealErr) transactionMock.On("Hash", mock.AnythingOfType("*types.Transaction")).Return(tt.args.hash) utils := &UtilsStruct{} - got, err := utils.Reveal(client, config, account, epoch, commitData, signature) + got, err := utils.Reveal(client, config, account, epoch, latestHeader, commitData, signature) if got != tt.want { t.Errorf("Txn hash for Reveal function, got = %v, want = %v", got, tt.want) } diff --git a/cmd/setConfig.go b/cmd/setConfig.go index 302a787e9..93320e3f3 100644 --- a/cmd/setConfig.go +++ b/cmd/setConfig.go @@ -1,4 +1,4 @@ -//Package cmd provides all functions related to command line +// Package cmd provides all functions related to command line package cmd import ( @@ -27,65 +27,72 @@ Example: }, } -//This function returns the error if there is any and sets the config +// This function returns the error if there is any and sets the config func (*UtilsStruct) SetConfig(flagSet *pflag.FlagSet) error { log.Debug("Checking to assign log file...") fileUtils.AssignLogFile(flagSet, types.Configurations{}) - provider, err := flagSetUtils.GetStringProvider(flagSet) - if err != nil { - return err - } - alternateProvider, err := flagSetUtils.GetStringAlternateProvider(flagSet) - if err != nil { - return err - } - gasMultiplier, err := flagSetUtils.GetFloat32GasMultiplier(flagSet) - if err != nil { - return err - } - bufferPercent, err := flagSetUtils.GetInt32Buffer(flagSet) - if err != nil { - return err - } - waitTime, err := flagSetUtils.GetInt32Wait(flagSet) - if err != nil { - return err - } - gasPrice, err := flagSetUtils.GetInt32GasPrice(flagSet) - if err != nil { - return err - } - logLevel, err := flagSetUtils.GetStringLogLevel(flagSet) - if err != nil { - return err - } - gasLimitOverride, err := flagSetUtils.GetUint64GasLimitOverride(flagSet) - if err != nil { - return err - } - gasLimit, err := flagSetUtils.GetFloat32GasLimit(flagSet) - if err != nil { - return err - } - rpcTimeout, rpcTimeoutErr := flagSetUtils.GetInt64RPCTimeout(flagSet) - if rpcTimeoutErr != nil { - return rpcTimeoutErr + + flagDetails := []types.FlagDetail{ + {Name: "provider", Type: "string"}, + {Name: "alternateProvider", Type: "string"}, + {Name: "gasmultiplier", Type: "float32"}, + {Name: "buffer", Type: "int32"}, + {Name: "wait", Type: "int32"}, + {Name: "gasprice", Type: "int32"}, + {Name: "logLevel", Type: "string"}, + {Name: "gasLimitOverride", Type: "uint64"}, + {Name: "gasLimit", Type: "float32"}, + {Name: "rpcTimeout", Type: "int64"}, + {Name: "httpTimeout", Type: "int64"}, + {Name: "logFileMaxSize", Type: "int"}, + {Name: "logFileMaxBackups", Type: "int"}, + {Name: "logFileMaxAge", Type: "int"}, } - httpTimeout, httpTimeoutErr := flagSetUtils.GetInt64HTTPTimeout(flagSet) - if httpTimeoutErr != nil { - return httpTimeoutErr + + // Storing the fetched flag values in a map + flagValues := make(map[string]interface{}) + for _, flagDetail := range flagDetails { + flagValue, err := flagSetUtils.FetchFlagInput(flagSet, flagDetail.Name, flagDetail.Type) + if err != nil { + log.Errorf("Error in fetching value for flag %v: %v", flagDetail.Name, err) + return err + } + flagValues[flagDetail.Name] = flagValue } - logFileMaxSize, err := flagSetUtils.GetIntLogFileMaxSize(flagSet) - if err != nil { - return err + + configDetails := []types.ConfigDetail{ + {FlagName: "provider", Key: "provider", DefaultValue: ""}, + {FlagName: "alternateProvider", Key: "alternateProvider", DefaultValue: ""}, + {FlagName: "gasmultiplier", Key: "gasmultiplier", DefaultValue: core.DefaultGasMultiplier}, + {FlagName: "buffer", Key: "buffer", DefaultValue: core.DefaultBufferPercent}, + {FlagName: "wait", Key: "wait", DefaultValue: core.DefaultWaitTime}, + {FlagName: "gasprice", Key: "gasprice", DefaultValue: core.DefaultGasPrice}, + {FlagName: "logLevel", Key: "logLevel", DefaultValue: core.DefaultLogLevel}, + {FlagName: "gasLimitOverride", Key: "gasLimitOverride", DefaultValue: core.DefaultGasLimitOverride}, + {FlagName: "gasLimit", Key: "gasLimit", DefaultValue: core.DefaultGasLimit}, + {FlagName: "rpcTimeout", Key: "rpcTimeout", DefaultValue: core.DefaultRPCTimeout}, + {FlagName: "httpTimeout", Key: "httpTimeout", DefaultValue: core.DefaultHTTPTimeout}, + {FlagName: "logFileMaxSize", Key: "logFileMaxSize", DefaultValue: core.DefaultLogFileMaxSize}, + {FlagName: "logFileMaxBackups", Key: "logFileMaxBackups", DefaultValue: core.DefaultLogFileMaxBackups}, + {FlagName: "logFileMaxAge", Key: "logFileMaxAge", DefaultValue: core.DefaultLogFileMaxAge}, } - logFileMaxBackups, err := flagSetUtils.GetIntLogFileMaxBackups(flagSet) - if err != nil { - return err + + var areConfigSet bool + + // Setting the respective config values in config file only if the flag was set with a value in `setConfig` command + for _, configDetail := range configDetails { + if flagValue, exists := flagValues[configDetail.FlagName]; exists { + // Check if the flag was set with a value in `setConfig` command + if flagSetUtils.Changed(flagSet, configDetail.FlagName) { + viper.Set(configDetail.Key, flagValue) + areConfigSet = true + } + } } - logFileMaxAge, err := flagSetUtils.GetIntLogFileMaxAge(flagSet) - if err != nil { - return err + + // If no config parameter was set than all the config parameters will be set to default config values + if !areConfigSet { + setDefaultConfigValues(configDetails) } path, pathErr := pathUtils.GetConfigFilePath() @@ -95,96 +102,47 @@ func (*UtilsStruct) SetConfig(flagSet *pflag.FlagSet) error { } if razorUtils.IsFlagPassed("exposeMetrics") { - port, err := flagSetUtils.GetStringExposeMetrics(flagSet) - if err != nil { - return err - } - - certKey, err := flagSetUtils.GetStringCertKey(flagSet) - if err != nil { - return err - } - certFile, err := flagSetUtils.GetStringCertFile(flagSet) - if err != nil { - return err - } - viper.Set("exposeMetricsPort", port) - - configErr := viperUtils.ViperWriteConfigAs(path) - if configErr != nil { - log.Error("Error in writing config") - return configErr + metricsErr := handleMetrics(flagSet) + if metricsErr != nil { + log.Error("Error in handling metrics: ", metricsErr) + return metricsErr } - - err = metrics.Run(port, certFile, certKey) - if err != nil { - log.Error("Failed to start metrics http server: ", err) - } - } - if provider != "" { - viper.Set("provider", provider) - } - if alternateProvider != "" { - viper.Set("alternateProvider", alternateProvider) - } - if gasMultiplier != -1 { - viper.Set("gasmultiplier", gasMultiplier) - } - if bufferPercent != 0 { - viper.Set("buffer", bufferPercent) - } - if waitTime != -1 { - viper.Set("wait", waitTime) - } - if gasPrice != -1 { - viper.Set("gasprice", gasPrice) - } - if logLevel != "" { - viper.Set("logLevel", logLevel) - } - if gasLimit != -1 { - viper.Set("gasLimit", gasLimit) } - if gasLimitOverride != 0 { - viper.Set("gasLimitOverride", gasLimitOverride) - } - if rpcTimeout != 0 { - viper.Set("rpcTimeout", rpcTimeout) - } - if httpTimeout != 0 { - viper.Set("httpTimeout", httpTimeout) + + configErr := viperUtils.ViperWriteConfigAs(path) + if configErr != nil { + log.Error("Error in writing config: ", configErr) + return configErr } - if logFileMaxSize != 0 { - viper.Set("logFileMaxSize", logFileMaxSize) + return nil +} + +func setDefaultConfigValues(configDetails []types.ConfigDetail) { + log.Info("No value is set to any flag in `setConfig` command") + log.Info("Setting the config values to default. Use `setConfig` again to modify the values.") + for _, configDetail := range configDetails { + viper.Set(configDetail.Key, configDetail.DefaultValue) } - if logFileMaxBackups != 0 { - viper.Set("logFileMaxBackups", logFileMaxBackups) +} + +func handleMetrics(flagSet *pflag.FlagSet) error { + port, err := flagSetUtils.FetchFlagInput(flagSet, "exposeMetrics", "string") + if err != nil { + return err } - if logFileMaxAge != 0 { - viper.Set("logFileMaxAge", logFileMaxAge) + certKey, err := flagSetUtils.FetchFlagInput(flagSet, "certKey", "string") + if err != nil { + return err } - if provider == "" && alternateProvider == "" && gasMultiplier == -1 && bufferPercent == 0 && waitTime == -1 && gasPrice == -1 && logLevel == "" && gasLimit == -1 && gasLimitOverride == 0 && rpcTimeout == 0 && httpTimeout == 0 && logFileMaxSize == 0 && logFileMaxBackups == 0 && logFileMaxAge == 0 { - viper.Set("provider", "") - viper.Set("alternateProvider", "") - viper.Set("gasmultiplier", core.DefaultGasMultiplier) - viper.Set("buffer", core.DefaultBufferPercent) - viper.Set("wait", core.DefaultWaitTime) - viper.Set("gasprice", core.DefaultGasPrice) - viper.Set("logLevel", core.DefaultLogLevel) - viper.Set("gasLimit", core.DefaultGasLimit) - viper.Set("gasLimitOverride", core.DefaultGasLimitOverride) - viper.Set("rpcTimeout", core.DefaultRPCTimeout) - viper.Set("httpTimeout", core.DefaultHTTPTimeout) - viper.Set("logFileMaxSize", core.DefaultLogFileMaxSize) - viper.Set("logFileMaxBackups", core.DefaultLogFileMaxBackups) - viper.Set("logFileMaxAge", core.DefaultLogFileMaxAge) - log.Info("Config values set to default. Use setConfig to modify the values.") + certFile, err := flagSetUtils.FetchFlagInput(flagSet, "certFile", "string") + if err != nil { + return err } + viper.Set("exposeMetricsPort", port) - configErr := viperUtils.ViperWriteConfigAs(path) - if configErr != nil { - log.Error("Error in writing config") - return configErr + err = metrics.Run(port.(string), certFile.(string), certKey.(string)) + if err != nil { + log.Error("Failed to start metrics http server: ", err) } return nil } @@ -215,7 +173,7 @@ func init() { setConfig.Flags().StringVarP(&AlternateProvider, "alternateProvider", "", "", "alternate provider name") setConfig.Flags().Float32VarP(&GasMultiplier, "gasmultiplier", "g", -1, "gas multiplier value") setConfig.Flags().Int32VarP(&BufferPercent, "buffer", "b", 0, "buffer percent") - setConfig.Flags().Int32VarP(&WaitTime, "wait", "w", -1, "wait time (in secs)") + setConfig.Flags().Int32VarP(&WaitTime, "wait", "w", 0, "wait time (in secs)") setConfig.Flags().Int32VarP(&GasPrice, "gasprice", "", -1, "custom gas price") setConfig.Flags().StringVarP(&LogLevel, "logLevel", "", "", "log level") setConfig.Flags().Float32VarP(&GasLimitMultiplier, "gasLimit", "", -1, "gas limit percentage increase") diff --git a/cmd/setConfig_test.go b/cmd/setConfig_test.go index a16a06dba..b25f7a858 100644 --- a/cmd/setConfig_test.go +++ b/cmd/setConfig_test.go @@ -12,44 +12,13 @@ func TestSetConfig(t *testing.T) { var flagSet *pflag.FlagSet type args struct { - provider string - providerErr error - alternateProvider string - alternateProviderErr error - gasmultiplier float32 - gasmultiplierErr error - buffer int32 - bufferErr error - waitTime int32 - waitTimeErr error - gasPrice int32 - gasPriceErr error - logLevel string - logLevelErr error - path string - pathErr error - configErr error - gasLimitMultiplier float32 - gasLimitMultiplierErr error - gasLimitOverride uint64 - gasLimitOverrideErr error - rpcTimeout int64 - rpcTimeoutErr error - httpTimeout int64 - httpTimeoutErr error - isFlagPassed bool - port string - portErr error - certFile string - certFileErr error - certKey string - certKeyErr error - logFileMaxSize int - logFileMaxSizeErr error - logFileMaxBackups int - logFileMaxBackupsErr error - logFileMaxAge int - logFileMaxAgeErr error + flagInput string + flagInputErr error + isFlagPassed bool + path string + pathErr error + isExposeMetricsFlagPassed bool + configErr error } tests := []struct { name string @@ -59,258 +28,57 @@ func TestSetConfig(t *testing.T) { { name: "Test 1: When values are passed to all flags and setConfig returns no error", args: args{ - provider: "http://127.0.0.1", - alternateProvider: "http://127.0.0.1:8545", - gasmultiplier: 2, - buffer: 20, - waitTime: 2, - gasPrice: 1, - logLevel: "debug", - path: "/home/config", - gasLimitMultiplier: 10, - rpcTimeout: 10, - httpTimeout: 20, - logFileMaxSize: 6, - logFileMaxBackups: 11, - logFileMaxAge: 31, + flagInput: "http://127.0.0.1", + path: "/home/config", + isFlagPassed: true, + isExposeMetricsFlagPassed: false, }, wantErr: nil, }, { - name: "Test 2: When parameters are set to default values and setConfig returns no error", + name: "Test 2: When there are no values passed as flag and all config values are default values", args: args{ - provider: "", - gasmultiplier: -1, - buffer: 0, - waitTime: -1, - gasPrice: -1, - logLevel: "", - path: "/home/config", - gasLimitMultiplier: 10, - rpcTimeout: 0, - httpTimeout: 0, + path: "/home/config", + isFlagPassed: false, + isExposeMetricsFlagPassed: false, }, wantErr: nil, }, { - name: "Test 3: When there is an error in getting provider", + name: "Test 3: When there is an error in running metrics server", args: args{ - providerErr: errors.New("provider error"), + flagInput: "8080", + path: "/home/config", + isFlagPassed: true, + isExposeMetricsFlagPassed: true, }, - wantErr: errors.New("provider error"), }, { - name: "Test 4: When there is an error in getting gasmultiplier", + name: "Test 4: When there is an error in getting path", args: args{ - provider: "http://127.0.0.1", - gasmultiplierErr: errors.New("gasmultiplier error"), - }, - wantErr: errors.New("gasmultiplier error"), - }, - { - name: "Test 5: When there is an error in getting buffer", - args: args{ - provider: "http://127.0.0.1", - gasmultiplier: 2, - bufferErr: errors.New("buffer error"), - }, - wantErr: errors.New("buffer error"), - }, - { - name: "Test 6: When there is an error in getting waitTime", - args: args{ - provider: "http://127.0.0.1", - gasmultiplier: 2, - buffer: 20, - waitTimeErr: errors.New("waitTime error"), - }, - wantErr: errors.New("waitTime error"), - }, - { - name: "Test 7: When there is an error in getting gasprice", - args: args{ - provider: "http://127.0.0.1", - gasmultiplier: 2, - buffer: 20, - waitTime: 2, - gasPriceErr: errors.New("gasprice error"), - }, - wantErr: errors.New("gasprice error"), - }, - { - name: "Test 8: When there is an error in getting logLevel", - args: args{ - provider: "http://127.0.0.1", - gasmultiplier: 2, - buffer: 20, - waitTime: 2, - gasPrice: 1, - logLevelErr: errors.New("logLevel error"), - }, - wantErr: errors.New("logLevel error"), - }, - { - name: "Test 9: When there is an error in getting path", - args: args{ - provider: "http://127.0.0.1", - gasmultiplier: 2, - buffer: 20, - waitTime: 2, - gasPrice: 1, - logLevel: "debug", - pathErr: errors.New("path error"), + pathErr: errors.New("path error"), }, wantErr: errors.New("path error"), }, { - name: "Test 10: When there is an error in writing config", - args: args{ - provider: "http://127.0.0.1", - gasmultiplier: 2, - buffer: 20, - waitTime: 2, - gasPrice: 1, - logLevel: "debug", - path: "/home/config", - configErr: errors.New("writing config error"), - }, - wantErr: errors.New("writing config error"), - }, - { - name: "Test 11: When only one of the flags is passed", - args: args{ - gasmultiplier: 2, - path: "/home/config", - configErr: nil, - }, - wantErr: nil, - }, - { - name: "Test 12: When there is an error in getting gas limit", + name: "Test 5: When there is an error in writing config", args: args{ - provider: "http://127.0.0.1", - gasmultiplier: 2, - buffer: 20, - waitTime: 2, - gasPrice: 1, - logLevel: "debug", - path: "/home/config", - gasLimitMultiplier: -1, - gasLimitMultiplierErr: errors.New("gasLimitMultiplier error"), - }, - wantErr: errors.New("gasLimitMultiplier error"), - }, - { - name: "Test 13: When default nil values are passed", - args: args{ - provider: "", - gasmultiplier: -1, - buffer: 0, - waitTime: -1, - gasPrice: -1, - logLevel: "", - rpcTimeout: 0, - httpTimeout: 0, - path: "/home/config", - gasLimitMultiplier: -1, - }, - wantErr: nil, - }, - { - name: "Test 14: When exposeMetrics flag is passed", - args: args{ - isFlagPassed: true, - port: "", - configErr: errors.New("config error"), + path: "/home/config", + isFlagPassed: false, + isExposeMetricsFlagPassed: false, + configErr: errors.New("config error"), }, wantErr: errors.New("config error"), }, - { - name: "Test 15: When there is an error in getting port", - args: args{ - isFlagPassed: true, - portErr: errors.New("error in getting port"), - }, - wantErr: errors.New("error in getting port"), - }, - { - name: "Test 16: When there is an error in getting RPC timeout", - args: args{ - provider: "http://127.0.0.1", - gasmultiplier: 2, - buffer: 20, - waitTime: 2, - gasPrice: 1, - logLevel: "debug", - path: "/home/config", - gasLimitMultiplier: -1, - rpcTimeoutErr: errors.New("rpcTimeout error"), - }, - wantErr: errors.New("rpcTimeout error"), - }, - { - name: "Test 17: When there is an error in getting gas limit to overrride", - args: args{ - provider: "http://127.0.0.1", - gasmultiplier: 2, - buffer: 20, - waitTime: 2, - gasPrice: 1, - logLevel: "debug", - path: "/home/config", - gasLimitMultiplier: -1, - gasLimitOverrideErr: errors.New("gasLimitOverride error"), - }, - wantErr: errors.New("gasLimitOverride error"), - }, - { - name: "Test 18: When there is an error in getting HTTP timeout", - args: args{ - provider: "http://127.0.0.1", - gasmultiplier: 2, - buffer: 20, - waitTime: 2, - gasPrice: 1, - logLevel: "debug", - path: "/home/config", - gasLimitMultiplier: -1, - rpcTimeout: 10, - httpTimeoutErr: errors.New("httpTimeout error"), - }, - wantErr: errors.New("httpTimeout error"), - }, - { - name: "Test 18: When there is an error in getting alternate provider", - args: args{ - provider: "http://127.0.0.1", - alternateProviderErr: errors.New("alternate provider error"), - }, - wantErr: errors.New("alternate provider error"), - }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { SetUpMockInterfaces() fileUtilsMock.On("AssignLogFile", mock.AnythingOfType("*pflag.FlagSet"), mock.Anything) - flagSetMock.On("GetStringProvider", flagSet).Return(tt.args.provider, tt.args.providerErr) - flagSetMock.On("GetStringAlternateProvider", flagSet).Return(tt.args.alternateProvider, tt.args.alternateProviderErr) - flagSetMock.On("GetFloat32GasMultiplier", flagSet).Return(tt.args.gasmultiplier, tt.args.gasmultiplierErr) - flagSetMock.On("GetInt32Buffer", flagSet).Return(tt.args.buffer, tt.args.bufferErr) - flagSetMock.On("GetInt32Wait", flagSet).Return(tt.args.waitTime, tt.args.waitTimeErr) - flagSetMock.On("GetInt32GasPrice", flagSet).Return(tt.args.gasPrice, tt.args.gasPriceErr) - flagSetMock.On("GetStringLogLevel", flagSet).Return(tt.args.logLevel, tt.args.logLevelErr) - flagSetMock.On("GetFloat32GasLimit", flagSet).Return(tt.args.gasLimitMultiplier, tt.args.gasLimitMultiplierErr) - flagSetMock.On("GetUint64GasLimitOverride", flagSet).Return(tt.args.gasLimitOverride, tt.args.gasLimitOverrideErr) - flagSetMock.On("GetInt64RPCTimeout", flagSet).Return(tt.args.rpcTimeout, tt.args.rpcTimeoutErr) - flagSetMock.On("GetInt64HTTPTimeout", flagSet).Return(tt.args.httpTimeout, tt.args.httpTimeoutErr) - flagSetMock.On("GetStringExposeMetrics", flagSet).Return(tt.args.port, tt.args.portErr) - flagSetMock.On("GetStringCertFile", flagSet).Return(tt.args.certFile, tt.args.certFileErr) - flagSetMock.On("GetStringCertKey", flagSet).Return(tt.args.certKey, tt.args.certKeyErr) - flagSetMock.On("GetIntLogFileMaxSize", mock.Anything).Return(tt.args.logFileMaxSize, tt.args.logFileMaxSizeErr) - flagSetMock.On("GetIntLogFileMaxBackups", mock.Anything).Return(tt.args.logFileMaxBackups, tt.args.logFileMaxBackupsErr) - flagSetMock.On("GetIntLogFileMaxAge", mock.Anything).Return(tt.args.logFileMaxAge, tt.args.logFileMaxAgeErr) - utilsMock.On("IsFlagPassed", mock.Anything).Return(tt.args.isFlagPassed) + flagSetMock.On("FetchFlagInput", flagSet, mock.Anything, mock.Anything).Return(tt.args.flagInput, tt.args.flagInputErr) + flagSetMock.On("Changed", mock.Anything, mock.Anything).Return(tt.args.isFlagPassed) + utilsMock.On("IsFlagPassed", mock.Anything).Return(tt.args.isExposeMetricsFlagPassed) pathMock.On("GetConfigFilePath").Return(tt.args.path, tt.args.pathErr) viperMock.On("ViperWriteConfigAs", mock.AnythingOfType("string")).Return(tt.args.configErr) diff --git a/cmd/setDelegation.go b/cmd/setDelegation.go index 30b41fe43..f28ca97d3 100644 --- a/cmd/setDelegation.go +++ b/cmd/setDelegation.go @@ -2,6 +2,7 @@ package cmd import ( + "razor/accounts" "razor/core" "razor/core/types" "razor/logger" @@ -50,7 +51,12 @@ func (*UtilsStruct) ExecuteSetDelegation(flagSet *pflag.FlagSet) { log.Debug("Getting password...") password := razorUtils.AssignPassword(flagSet) - err = razorUtils.CheckPassword(address, password) + accountManager, err := razorUtils.AccountManagerForKeystore() + utils.CheckError("Error in getting accounts manager for keystore: ", err) + + account := accounts.InitAccountStruct(address, password, accountManager) + + err = razorUtils.CheckPassword(account) utils.CheckError("Error in fetching private key from given password: ", err) statusString, err := flagSetUtils.GetStringStatus(flagSet) @@ -66,15 +72,13 @@ func (*UtilsStruct) ExecuteSetDelegation(flagSet *pflag.FlagSet) { utils.CheckError("Error in fetching commission: ", err) delegationInput := types.SetDelegationInput{ - Address: address, - Password: password, Status: status, StatusString: statusString, StakerId: stakerId, Commission: commission, + Account: account, } - log.Debugf("ExecuteSetDelegation: Calling SetDelegation() with argument delegationInput = %+v", delegationInput) txn, err := cmdUtils.SetDelegation(client, config, delegationInput) utils.CheckError("SetDelegation error: ", err) if txn != core.NilHash { @@ -93,11 +97,9 @@ func (*UtilsStruct) SetDelegation(client *ethclient.Client, config types.Configu if delegationInput.Commission != 0 { updateCommissionInput := types.UpdateCommissionInput{ StakerId: delegationInput.StakerId, - Address: delegationInput.Address, - Password: delegationInput.Password, Commission: delegationInput.Commission, + Account: delegationInput.Account, } - log.Debugf("Calling UpdateCommission() with argument updateCommissionInput = %+v", updateCommissionInput) err = cmdUtils.UpdateCommission(config, client, updateCommissionInput) if err != nil { return core.NilHash, err @@ -106,14 +108,13 @@ func (*UtilsStruct) SetDelegation(client *ethclient.Client, config types.Configu txnOpts := types.TransactionOptions{ Client: client, - Password: delegationInput.Password, - AccountAddress: delegationInput.Address, ChainId: core.ChainId, Config: config, ContractAddress: core.StakeManagerAddress, ABI: bindings.StakeManagerMetaData.ABI, MethodName: "setDelegationAcceptance", Parameters: []interface{}{delegationInput.Status}, + Account: delegationInput.Account, } if stakerInfo.AcceptDelegation == delegationInput.Status { diff --git a/cmd/setDelegation_test.go b/cmd/setDelegation_test.go index b852eb7f1..e9cc4b2b1 100644 --- a/cmd/setDelegation_test.go +++ b/cmd/setDelegation_test.go @@ -1,11 +1,9 @@ package cmd import ( - "crypto/ecdsa" - "crypto/rand" "errors" - "github.com/ethereum/go-ethereum/crypto" "math/big" + "razor/accounts" "razor/cmd/mocks" "razor/core" "razor/core/types" @@ -13,7 +11,6 @@ import ( utilsPkgMocks "razor/utils/mocks" "testing" - "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" Types "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethclient" @@ -31,12 +28,8 @@ func TestSetDelegation(t *testing.T) { WaitTime: 1, } - privateKey, _ := ecdsa.GenerateKey(crypto.S256(), rand.Reader) - txnOpts, _ := bind.NewKeyedTransactorWithChainID(privateKey, big.NewInt(1)) - type args struct { status bool - txnOpts *bind.TransactOpts staker bindings.StructsStaker stakerErr error setDelegationAcceptanceTxn *Types.Transaction @@ -54,7 +47,6 @@ func TestSetDelegation(t *testing.T) { { name: "Test 1: When SetDelegation function executes successfully", args: args{ - txnOpts: txnOpts, staker: bindings.StructsStaker{ AcceptDelegation: true, }, @@ -69,7 +61,6 @@ func TestSetDelegation(t *testing.T) { { name: "Test 2: When setDelegationAcceptance transaction fails", args: args{ - txnOpts: txnOpts, staker: bindings.StructsStaker{ AcceptDelegation: true, }, @@ -84,7 +75,6 @@ func TestSetDelegation(t *testing.T) { { name: "Test 3: When there is an error in getting staker", args: args{ - txnOpts: txnOpts, stakerErr: errors.New("staker error"), setDelegationAcceptanceTxn: &Types.Transaction{}, setDelegationAcceptanceErr: nil, @@ -96,8 +86,7 @@ func TestSetDelegation(t *testing.T) { { name: "Test 4: When stakerInfo.AcceptDelegation == delegationInput.Status", args: args{ - status: true, - txnOpts: txnOpts, + status: true, staker: bindings.StructsStaker{ AcceptDelegation: true, }, @@ -112,7 +101,6 @@ func TestSetDelegation(t *testing.T) { { name: "Test 5: When commission is non zero and UpdateCommission executes successfully", args: args{ - txnOpts: txnOpts, staker: bindings.StructsStaker{ AcceptDelegation: true, }, @@ -129,7 +117,6 @@ func TestSetDelegation(t *testing.T) { { name: "Test 6: When commission is non zero and UpdateCommission does not executes successfully", args: args{ - txnOpts: txnOpts, staker: bindings.StructsStaker{ AcceptDelegation: true, }, @@ -159,7 +146,7 @@ func TestSetDelegation(t *testing.T) { utilsMock.On("GetStaker", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("uint32")).Return(tt.args.staker, tt.args.stakerErr) cmdUtilsMock.On("UpdateCommission", mock.Anything, mock.Anything, mock.Anything).Return(tt.args.UpdateCommissionErr) - utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(txnOpts) + utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(TxnOpts) stakeManagerUtilsMock.On("SetDelegationAcceptance", mock.AnythingOfType("*ethclient.Client"), mock.Anything, mock.AnythingOfType("bool")).Return(tt.args.setDelegationAcceptanceTxn, tt.args.setDelegationAcceptanceErr) transactionUtilsMock.On("Hash", mock.Anything).Return(tt.args.hash) @@ -418,7 +405,8 @@ func TestExecuteSetDelegation(t *testing.T) { fileUtilsMock.On("AssignLogFile", mock.AnythingOfType("*pflag.FlagSet"), mock.Anything) cmdUtilsMock.On("GetConfigData").Return(tt.args.config, tt.args.configErr) utilsMock.On("AssignPassword", flagSet).Return(tt.args.password) - utilsMock.On("CheckPassword", mock.Anything, mock.Anything).Return(nil) + utilsMock.On("CheckPassword", mock.Anything).Return(nil) + utilsMock.On("AccountManagerForKeystore").Return(&accounts.AccountManager{}, nil) flagSetUtilsMock.On("GetStringAddress", flagSet).Return(tt.args.address, tt.args.addressErr) flagSetUtilsMock.On("GetStringStatus", flagSet).Return(tt.args.status, tt.args.statusErr) flagSetUtilsMock.On("GetUint8Commission", flagSet).Return(tt.args.commission, tt.args.commissionErr) diff --git a/cmd/struct-utils.go b/cmd/struct-utils.go index b5e76e457..eebc2712f 100644 --- a/cmd/struct-utils.go +++ b/cmd/struct-utils.go @@ -1,11 +1,11 @@ -//Package cmd provides all functions related to command line +// Package cmd provides all functions related to command line package cmd import ( "crypto/ecdsa" + "errors" "math/big" "os" - "razor/accounts" "razor/core" "razor/core/types" "razor/path" @@ -28,13 +28,12 @@ import ( ) var ( - razorUtils = utils.UtilsInterface - pathUtils = path.PathUtilsInterface - clientUtils = utils.ClientInterface - fileUtils = utils.FileInterface - gasUtils = utils.GasInterface - merkleUtils = utils.MerkleInterface - accountUtils = accounts.AccountUtilsInterface + razorUtils = utils.UtilsInterface + pathUtils = path.PathUtilsInterface + clientUtils = utils.ClientInterface + fileUtils = utils.FileInterface + gasUtils = utils.GasInterface + merkleUtils = utils.MerkleInterface ) //This function initializes the utils @@ -51,7 +50,6 @@ func InitializeUtils() { utils.ABIInterface = &utils.ABIStruct{} utils.PathInterface = &utils.PathStruct{} utils.BindInterface = &utils.BindStruct{} - utils.AccountsInterface = &utils.AccountsStruct{} utils.BlockManagerInterface = &utils.BlockManagerStruct{} utils.StakeManagerInterface = &utils.StakeManagerStruct{} utils.AssetManagerInterface = &utils.AssetManagerStruct{} @@ -70,8 +68,6 @@ func InitializeUtils() { utils.GasInterface = &utils.GasStruct{} merkleUtils = &utils.MerkleTreeStruct{} utils.MerkleInterface = &utils.MerkleTreeStruct{} - accountUtils = &accounts.AccountUtils{} - accounts.AccountUtilsInterface = &accounts.AccountUtils{} } func ExecuteTransaction(interfaceName interface{}, methodName string, args ...interface{}) (*Types.Transaction, error) { @@ -83,6 +79,55 @@ func ExecuteTransaction(interfaceName interface{}, methodName string, args ...in return returnedValues[0].Interface().(*Types.Transaction), nil } +// FetchFlagInput fetches input value of the flag with given data type and specified flag keyword +func (flagSetUtils FLagSetUtils) FetchFlagInput(flagSet *pflag.FlagSet, flagName string, dataType string) (interface{}, error) { + switch dataType { + case "string": + return flagSet.GetString(flagName) + case "float32": + return flagSet.GetFloat32(flagName) + case "int32": + return flagSet.GetInt32(flagName) + case "int64": + return flagSet.GetInt64(flagName) + case "uint64": + return flagSet.GetUint64(flagName) + case "int": + return flagSet.GetInt(flagName) + case "bool": + return flagSet.GetBool(flagName) + default: + return nil, errors.New("unsupported data type for flag input") + } +} + +// FetchRootFlagInput fetches input value of the root flag with given data type and specified flag keyword +func (flagSetUtils FLagSetUtils) FetchRootFlagInput(flagName string, dataType string) (interface{}, error) { + switch dataType { + case "string": + return rootCmd.PersistentFlags().GetString(flagName) + case "float32": + return rootCmd.PersistentFlags().GetFloat32(flagName) + case "int32": + return rootCmd.PersistentFlags().GetInt32(flagName) + case "int64": + return rootCmd.PersistentFlags().GetInt64(flagName) + case "uint64": + return rootCmd.PersistentFlags().GetUint64(flagName) + case "int": + return rootCmd.PersistentFlags().GetInt(flagName) + case "bool": + return rootCmd.PersistentFlags().GetBool(flagName) + default: + return nil, errors.New("unsupported data type for root flag input") + } +} + +// Changed returns true if flag was passed in the command else returns false +func (flagSetUtils FLagSetUtils) Changed(flagSet *pflag.FlagSet, flagName string) bool { + return flagSet.Changed(flagName) +} + //This function returns the hash func (transactionUtils TransactionUtils) Hash(txn *Types.Transaction) common.Hash { return txn.Hash() @@ -213,7 +258,7 @@ func (blockManagerUtils BlockManagerUtils) ClaimBlockReward(client *ethclient.Cl return ExecuteTransaction(blockManager, "ClaimBlockReward", opts) } -//Thid function is used to finalize the dispute +// Thid function is used to finalize the dispute func (blockManagerUtils BlockManagerUtils) FinalizeDispute(client *ethclient.Client, opts *bind.TransactOpts, epoch uint32, blockIndex uint8, positionOfCollectionInBlock *big.Int) (*Types.Transaction, error) { blockManager := razorUtils.GetBlockManager(client) var ( @@ -350,7 +395,7 @@ func (blockManagerUtils BlockManagerUtils) ResetDispute(blockManager *bindings.B return ExecuteTransaction(blockManager, "ResetDispute", opts, epoch) } -//This functiom gets Disputes mapping +// This functiom gets Disputes mapping func (blockManagerUtils BlockManagerUtils) Disputes(client *ethclient.Client, opts *bind.CallOpts, epoch uint32, address common.Address) (types.DisputesStruct, error) { blockManager := razorUtils.GetBlockManager(client) returnedValues := utils.InvokeFunctionWithTimeout(blockManager, "Disputes", opts, epoch, address) @@ -473,135 +518,11 @@ func (assetManagerUtils AssetManagerUtils) UpdateCollection(client *ethclient.Cl return ExecuteTransaction(assetManager, "UpdateCollection", opts, collectionId, tolerance, aggregationMethod, power, jobIds) } -//This function returns the provider in string -func (flagSetUtils FLagSetUtils) GetStringProvider(flagSet *pflag.FlagSet) (string, error) { - return flagSet.GetString("provider") -} - -func (flagSetUtils FLagSetUtils) GetStringAlternateProvider(flagSet *pflag.FlagSet) (string, error) { - return flagSet.GetString("alternateProvider") -} - -//This function returns gas multiplier in float 32 -func (flagSetUtils FLagSetUtils) GetFloat32GasMultiplier(flagSet *pflag.FlagSet) (float32, error) { - return flagSet.GetFloat32("gasmultiplier") -} - -//This function returns Buffer in Int32 -func (flagSetUtils FLagSetUtils) GetInt32Buffer(flagSet *pflag.FlagSet) (int32, error) { - return flagSet.GetInt32("buffer") -} - -//This function returns Wait in Int32 -func (flagSetUtils FLagSetUtils) GetInt32Wait(flagSet *pflag.FlagSet) (int32, error) { - return flagSet.GetInt32("wait") -} - -//This function returns GasPrice in Int32 -func (flagSetUtils FLagSetUtils) GetInt32GasPrice(flagSet *pflag.FlagSet) (int32, error) { - return flagSet.GetInt32("gasprice") -} - -//This function returns Log Level in string -func (flagSetUtils FLagSetUtils) GetStringLogLevel(flagSet *pflag.FlagSet) (string, error) { - return flagSet.GetString("logLevel") -} - -//This function returns RPC Timeout in Int64 -func (flagSetUtils FLagSetUtils) GetInt64RPCTimeout(flagSet *pflag.FlagSet) (int64, error) { - return flagSet.GetInt64("rpcTimeout") -} - -//This function returns GasLimit to override in Uint64 -func (flagSetUtils FLagSetUtils) GetUint64GasLimitOverride(flagSet *pflag.FlagSet) (uint64, error) { - return flagSet.GetUint64("gasLimitOverride") -} - -//This function returns HTTP Timeout in Int64 -func (flagSetUtils FLagSetUtils) GetInt64HTTPTimeout(flagSet *pflag.FlagSet) (int64, error) { - return flagSet.GetInt64("httpTimeout") -} - -//This function returns Gas Limit in Float32 -func (flagSetUtils FLagSetUtils) GetFloat32GasLimit(flagSet *pflag.FlagSet) (float32, error) { - return flagSet.GetFloat32("gasLimit") -} - //This function returns BountyId in Uint32 func (flagSetUtils FLagSetUtils) GetUint32BountyId(flagSet *pflag.FlagSet) (uint32, error) { return flagSet.GetUint32("bountyId") } -//This function returns the provider of root in string -func (flagSetUtils FLagSetUtils) GetRootStringProvider() (string, error) { - return rootCmd.PersistentFlags().GetString("provider") -} - -//This function returns the alternate provider of root in string -func (flagSetUtils FLagSetUtils) GetRootStringAlternateProvider() (string, error) { - return rootCmd.PersistentFlags().GetString("alternateProvider") -} - -//This function returns the gas multiplier of root in float32 -func (flagSetUtils FLagSetUtils) GetRootFloat32GasMultiplier() (float32, error) { - return rootCmd.PersistentFlags().GetFloat32("gasmultiplier") -} - -//This function returns the buffer of root in Int32 -func (flagSetUtils FLagSetUtils) GetRootInt32Buffer() (int32, error) { - return rootCmd.PersistentFlags().GetInt32("buffer") -} - -//This function returns the wait of root in Int32 -func (flagSetUtils FLagSetUtils) GetRootInt32Wait() (int32, error) { - return rootCmd.PersistentFlags().GetInt32("wait") -} - -//This function returns the gas price of root in Int32 -func (flagSetUtils FLagSetUtils) GetRootInt32GasPrice() (int32, error) { - return rootCmd.PersistentFlags().GetInt32("gasprice") -} - -//This function returns the log level of root in string -func (flagSetUtils FLagSetUtils) GetRootStringLogLevel() (string, error) { - return rootCmd.PersistentFlags().GetString("logLevel") -} - -//This function returns the gas limit of root in Float32 -func (flagSetUtils FLagSetUtils) GetRootFloat32GasLimit() (float32, error) { - return rootCmd.PersistentFlags().GetFloat32("gasLimit") -} - -//This function returns the gas limit to overridr of root in Uint64 -func (flagSetUtils FLagSetUtils) GetRootUint64GasLimitOverride() (uint64, error) { - return rootCmd.PersistentFlags().GetUint64("gasLimitOverride") -} - -//This function returns the rpcTimeout of root in Int64 -func (flagSetUtils FLagSetUtils) GetRootInt64RPCTimeout() (int64, error) { - return rootCmd.PersistentFlags().GetInt64("rpcTimeout") -} - -//This function returns the HTTPTimeout of root in Int64 -func (flagSetUtils FLagSetUtils) GetRootInt64HTTPTimeout() (int64, error) { - return rootCmd.PersistentFlags().GetInt64("httpTimeout") -} - -//This function returns the max size of log file for root flag in Int -func (flagSetUtils FLagSetUtils) GetRootIntLogFileMaxSize() (int, error) { - return rootCmd.PersistentFlags().GetInt("logFileMaxSize") -} - -//This function returns the max number of backups for logFile for root flag in Int -func (flagSetUtils FLagSetUtils) GetRootIntLogFileMaxBackups() (int, error) { - return rootCmd.PersistentFlags().GetInt("logFileMaxBackups") -} - -//This function returns the max age of logFle for root file in Int -func (flagSetUtils FLagSetUtils) GetRootIntLogFileMaxAge() (int, error) { - return rootCmd.PersistentFlags().GetInt("logFileMaxAge") -} - //This function returns the from in string func (flagSetUtils FLagSetUtils) GetStringFrom(flagSet *pflag.FlagSet) (string, error) { from, err := flagSet.GetString("from") diff --git a/cmd/transfer.go b/cmd/transfer.go index b467ee422..bbb7517b1 100644 --- a/cmd/transfer.go +++ b/cmd/transfer.go @@ -2,6 +2,7 @@ package cmd import ( + "razor/accounts" "razor/core" "razor/core/types" "razor/logger" @@ -50,7 +51,12 @@ func (*UtilsStruct) ExecuteTransfer(flagSet *pflag.FlagSet) { log.Debug("Getting password...") password := razorUtils.AssignPassword(flagSet) - err = razorUtils.CheckPassword(fromAddress, password) + accountManager, err := razorUtils.AccountManagerForKeystore() + utils.CheckError("Error in getting accounts manager for keystore: ", err) + + account := accounts.InitAccountStruct(fromAddress, password, accountManager) + + err = razorUtils.CheckPassword(account) utils.CheckError("Error in fetching private key from given password: ", err) toAddress, err := flagSetUtils.GetStringTo(flagSet) @@ -64,14 +70,12 @@ func (*UtilsStruct) ExecuteTransfer(flagSet *pflag.FlagSet) { utils.CheckError("Error in getting amount: ", err) transferInput := types.TransferInput{ - FromAddress: fromAddress, - ToAddress: toAddress, - Password: password, - ValueInWei: valueInWei, - Balance: balance, + ToAddress: toAddress, + ValueInWei: valueInWei, + Balance: balance, + Account: account, } - log.Debugf("Calling Transfer() with arguments transferInput = %+v", transferInput) txn, err := cmdUtils.Transfer(client, config, transferInput) utils.CheckError("Transfer error: ", err) @@ -86,16 +90,15 @@ func (*UtilsStruct) Transfer(client *ethclient.Client, config types.Configuratio txnOpts := razorUtils.GetTxnOpts(types.TransactionOptions{ Client: client, - Password: transferInput.Password, - AccountAddress: transferInput.FromAddress, ChainId: core.ChainId, Config: config, ContractAddress: core.RAZORAddress, MethodName: "transfer", Parameters: []interface{}{common.HexToAddress(transferInput.ToAddress), transferInput.ValueInWei}, ABI: bindings.RAZORMetaData.ABI, + Account: transferInput.Account, }) - log.Infof("Transferring %g tokens from %s to %s", utils.GetAmountInDecimal(transferInput.ValueInWei), transferInput.FromAddress, transferInput.ToAddress) + log.Infof("Transferring %g tokens from %s to %s", utils.GetAmountInDecimal(transferInput.ValueInWei), transferInput.Account.Address, transferInput.ToAddress) log.Debugf("Executing Transfer transaction with toAddress: %s, amount: %s", transferInput.ToAddress, transferInput.ValueInWei) txn, err := tokenManagerUtils.Transfer(client, txnOpts, common.HexToAddress(transferInput.ToAddress), transferInput.ValueInWei) diff --git a/cmd/transfer_test.go b/cmd/transfer_test.go index 756a409df..d0e04e651 100644 --- a/cmd/transfer_test.go +++ b/cmd/transfer_test.go @@ -1,18 +1,15 @@ package cmd import ( - "crypto/ecdsa" - "crypto/rand" "errors" - "github.com/ethereum/go-ethereum/crypto" "math/big" + "razor/accounts" "razor/core" "razor/core/types" "testing" "github.com/stretchr/testify/mock" - "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" Types "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethclient" @@ -23,13 +20,9 @@ func TestTransfer(t *testing.T) { var client *ethclient.Client var config types.Configurations - privateKey, _ := ecdsa.GenerateKey(crypto.S256(), rand.Reader) - txnOpts, _ := bind.NewKeyedTransactorWithChainID(privateKey, big.NewInt(31000)) - type args struct { amount *big.Int decimalAmount *big.Float - txnOpts *bind.TransactOpts transferTxn *Types.Transaction transferErr error transferHash common.Hash @@ -45,7 +38,6 @@ func TestTransfer(t *testing.T) { args: args{ amount: big.NewInt(1).Mul(big.NewInt(1000), big.NewInt(1e18)), decimalAmount: big.NewFloat(1000), - txnOpts: txnOpts, transferTxn: &Types.Transaction{}, transferErr: nil, transferHash: common.BigToHash(big.NewInt(1)), @@ -58,7 +50,6 @@ func TestTransfer(t *testing.T) { args: args{ amount: big.NewInt(1).Mul(big.NewInt(1000), big.NewInt(1e18)), decimalAmount: big.NewFloat(1000), - txnOpts: txnOpts, transferTxn: &Types.Transaction{}, transferErr: errors.New("transfer error"), transferHash: common.BigToHash(big.NewInt(1)), @@ -72,7 +63,7 @@ func TestTransfer(t *testing.T) { SetUpMockInterfaces() utilsMock.On("CheckAmountAndBalance", mock.AnythingOfType("*big.Int"), mock.AnythingOfType("*big.Int")).Return(tt.args.amount) - utilsMock.On("GetTxnOpts", mock.Anything).Return(tt.args.txnOpts) + utilsMock.On("GetTxnOpts", mock.Anything).Return(TxnOpts) utilsMock.On("GetAmountInDecimal", mock.AnythingOfType("*big.Int")).Return(tt.args.decimalAmount) tokenManagerMock.On("Transfer", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("*bind.TransactOpts"), mock.AnythingOfType("common.Address"), mock.AnythingOfType("*big.Int")).Return(tt.args.transferTxn, tt.args.transferErr) transactionMock.On("Hash", mock.Anything).Return(tt.args.transferHash) @@ -255,7 +246,8 @@ func TestExecuteTransfer(t *testing.T) { fileUtilsMock.On("AssignLogFile", mock.AnythingOfType("*pflag.FlagSet"), mock.Anything) cmdUtilsMock.On("GetConfigData").Return(tt.args.config, tt.args.configErr) utilsMock.On("AssignPassword", flagSet).Return(tt.args.password) - utilsMock.On("CheckPassword", mock.Anything, mock.Anything).Return(nil) + utilsMock.On("CheckPassword", mock.Anything).Return(nil) + utilsMock.On("AccountManagerForKeystore").Return(&accounts.AccountManager{}, nil) flagSetMock.On("GetStringFrom", flagSet).Return(tt.args.from, tt.args.fromErr) flagSetMock.On("GetStringTo", flagSet).Return(tt.args.to, tt.args.toErr) cmdUtilsMock.On("AssignAmountInWei", flagSet).Return(tt.args.amount, tt.args.amountErr) diff --git a/cmd/unlockWithdraw.go b/cmd/unlockWithdraw.go index 7657beb24..5edd33c91 100644 --- a/cmd/unlockWithdraw.go +++ b/cmd/unlockWithdraw.go @@ -4,6 +4,7 @@ package cmd import ( "errors" "math/big" + "razor/accounts" "razor/core" "razor/core/types" "razor/logger" @@ -50,20 +51,21 @@ func (*UtilsStruct) ExecuteUnlockWithdraw(flagSet *pflag.FlagSet) { log.Debug("Getting password...") password := razorUtils.AssignPassword(flagSet) - err = razorUtils.CheckPassword(address, password) + accountManager, err := razorUtils.AccountManagerForKeystore() + utils.CheckError("Error in getting accounts manager for keystore: ", err) + + account := accounts.InitAccountStruct(address, password, accountManager) + + err = razorUtils.CheckPassword(account) utils.CheckError("Error in fetching private key from given password: ", err) stakerId, err := razorUtils.AssignStakerId(flagSet, client, address) utils.CheckError("Error in fetching stakerId: ", err) log.Debug("ExecuteUnlockWithdraw: StakerId: ", stakerId) - log.Debugf("ExecuteUnlockWithdraw: Calling HandleWithdrawLock with arguments account address = %s, stakerId = %d", address, stakerId) - txn, err := cmdUtils.HandleWithdrawLock(client, types.Account{ - Address: address, - Password: password, - }, config, stakerId) + txn, err := cmdUtils.HandleWithdrawLock(client, account, config, stakerId) + utils.CheckError("HandleWithdrawLock error: ", err) - utils.CheckError("UnlockWithdraw error: ", err) if txn != core.NilHash { err = razorUtils.WaitForBlockCompletion(client, txn.Hex()) utils.CheckError("Error in WaitForBlockCompletion for unlockWithdraw: ", err) @@ -100,14 +102,13 @@ func (*UtilsStruct) HandleWithdrawLock(client *ethclient.Client, account types.A if big.NewInt(int64(epoch)).Cmp(withdrawLock.UnlockAfter) >= 0 { txnArgs := types.TransactionOptions{ Client: client, - Password: account.Password, - AccountAddress: account.Address, ChainId: core.ChainId, Config: configurations, ContractAddress: core.StakeManagerAddress, MethodName: "unlockWithdraw", ABI: bindings.StakeManagerMetaData.ABI, Parameters: []interface{}{stakerId}, + Account: account, } txnOpts := razorUtils.GetTxnOpts(txnArgs) log.Debug("HandleWithdrawLock: Calling UnlockWithdraw() with arguments stakerId = ", stakerId) diff --git a/cmd/unlockWithdraw_test.go b/cmd/unlockWithdraw_test.go index 40bd5defa..210613e20 100644 --- a/cmd/unlockWithdraw_test.go +++ b/cmd/unlockWithdraw_test.go @@ -1,11 +1,9 @@ package cmd import ( - "crypto/ecdsa" - "crypto/rand" "errors" - "github.com/ethereum/go-ethereum/crypto" "math/big" + "razor/accounts" "razor/core" "razor/core/types" "reflect" @@ -100,7 +98,8 @@ func TestExecuteUnlockWithdraw(t *testing.T) { flagSetMock.On("GetStringAddress", mock.AnythingOfType("*pflag.FlagSet")).Return(tt.args.address, tt.args.addressErr) cmdUtilsMock.On("GetConfigData").Return(tt.args.config, tt.args.configErr) utilsMock.On("AssignPassword", flagSet).Return(tt.args.password) - utilsMock.On("CheckPassword", mock.Anything, mock.Anything).Return(nil) + utilsMock.On("CheckPassword", mock.Anything).Return(nil) + utilsMock.On("AccountManagerForKeystore").Return(&accounts.AccountManager{}, nil) utilsMock.On("ConnectToClient", mock.AnythingOfType("string")).Return(client) utilsMock.On("AssignStakerId", mock.Anything, mock.Anything, mock.Anything).Return(tt.args.stakerId, tt.args.stakerIdErr) cmdUtilsMock.On("HandleWithdrawLock", mock.AnythingOfType("*ethclient.Client"), mock.Anything, mock.Anything, mock.Anything).Return(tt.args.txn, tt.args.err) @@ -115,9 +114,6 @@ func TestExecuteUnlockWithdraw(t *testing.T) { } func TestHandleWithdrawLock(t *testing.T) { - privateKey, _ := ecdsa.GenerateKey(crypto.S256(), rand.Reader) - txnOpts, _ := bind.NewKeyedTransactorWithChainID(privateKey, big.NewInt(1)) - var ( client *ethclient.Client account types.Account @@ -130,7 +126,6 @@ func TestHandleWithdrawLock(t *testing.T) { withdrawLockErr error epoch uint32 epochErr error - txnOpts *bind.TransactOpts unlockWithdraw common.Hash unlockWithdrawErr error time string @@ -148,7 +143,6 @@ func TestHandleWithdrawLock(t *testing.T) { UnlockAfter: big.NewInt(4), }, epoch: 5, - txnOpts: txnOpts, unlockWithdraw: common.BigToHash(big.NewInt(1)), }, want: common.BigToHash(big.NewInt(1)), @@ -203,7 +197,7 @@ func TestHandleWithdrawLock(t *testing.T) { utilsMock.On("GetLock", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("string"), mock.AnythingOfType("uint32"), mock.Anything).Return(tt.args.withdrawLock, tt.args.withdrawLockErr) utilsMock.On("GetEpoch", mock.AnythingOfType("*ethclient.Client")).Return(tt.args.epoch, tt.args.epochErr) - utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(txnOpts) + utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(TxnOpts) cmdUtilsMock.On("UnlockWithdraw", mock.Anything, mock.Anything, mock.Anything).Return(tt.args.unlockWithdraw, tt.args.unlockWithdrawErr) utilsMock.On("SecondsToReadableTime", mock.AnythingOfType("int")).Return(tt.args.time) ut := &UtilsStruct{} diff --git a/cmd/unstake.go b/cmd/unstake.go index 40c6abbe0..af79f2591 100644 --- a/cmd/unstake.go +++ b/cmd/unstake.go @@ -4,6 +4,7 @@ package cmd import ( "errors" "math/big" + "razor/accounts" "razor/core" "razor/core/types" "razor/logger" @@ -53,7 +54,12 @@ func (*UtilsStruct) ExecuteUnstake(flagSet *pflag.FlagSet) { log.Debug("Getting password...") password := razorUtils.AssignPassword(flagSet) - err = razorUtils.CheckPassword(address, password) + accountManager, err := razorUtils.AccountManagerForKeystore() + utils.CheckError("Error in getting accounts manager for keystore: ", err) + + account := accounts.InitAccountStruct(address, password, accountManager) + + err = razorUtils.CheckPassword(account) utils.CheckError("Error in fetching private key from given password: ", err) log.Debug("Getting amount in wei...") @@ -64,13 +70,11 @@ func (*UtilsStruct) ExecuteUnstake(flagSet *pflag.FlagSet) { utils.CheckError("StakerId error: ", err) unstakeInput := types.UnstakeInput{ - Address: address, - Password: password, ValueInWei: valueInWei, StakerId: stakerId, + Account: account, } - log.Debugf("ExecuteUnstake: Calling Unstake() with arguments unstakeInput: %+v", unstakeInput) txnHash, err := cmdUtils.Unstake(config, client, unstakeInput) utils.CheckError("Unstake Error: ", err) if txnHash != core.NilHash { @@ -82,12 +86,11 @@ func (*UtilsStruct) ExecuteUnstake(flagSet *pflag.FlagSet) { //This function allows user to unstake their sRZRs in the razor network func (*UtilsStruct) Unstake(config types.Configurations, client *ethclient.Client, input types.UnstakeInput) (common.Hash, error) { txnArgs := types.TransactionOptions{ - Client: client, - Password: input.Password, - AccountAddress: input.Address, - Amount: input.ValueInWei, - ChainId: core.ChainId, - Config: config, + Client: client, + Amount: input.ValueInWei, + ChainId: core.ChainId, + Config: config, + Account: input.Account, } stakerId := input.StakerId staker, err := razorUtils.GetStaker(client, stakerId) @@ -116,7 +119,7 @@ func (*UtilsStruct) Unstake(config types.Configurations, client *ethclient.Clien txnArgs.MethodName = "unstake" txnArgs.ABI = bindings.StakeManagerMetaData.ABI - unstakeLock, err := razorUtils.GetLock(txnArgs.Client, txnArgs.AccountAddress, stakerId, 0) + unstakeLock, err := razorUtils.GetLock(txnArgs.Client, txnArgs.Account.Address, stakerId, 0) if err != nil { log.Error("Error in getting unstakeLock: ", err) return core.NilHash, err diff --git a/cmd/unstake_test.go b/cmd/unstake_test.go index 7db9851b8..3f8008dfe 100644 --- a/cmd/unstake_test.go +++ b/cmd/unstake_test.go @@ -6,6 +6,7 @@ import ( "errors" "github.com/ethereum/go-ethereum/crypto" "math/big" + "razor/accounts" "razor/core" "razor/core/types" "razor/pkg/bindings" @@ -21,13 +22,9 @@ import ( ) func TestUnstake(t *testing.T) { - privateKey, _ := ecdsa.GenerateKey(crypto.S256(), rand.Reader) - txnOpts, _ := bind.NewKeyedTransactorWithChainID(privateKey, big.NewInt(1)) - var config types.Configurations var client *ethclient.Client - var address string - var password string + var account types.Account var stakerId uint32 type args struct { @@ -130,15 +127,14 @@ func TestUnstake(t *testing.T) { utilsMock.On("WaitForBlockCompletion", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("string")).Return(nil) utilsMock.On("GetLock", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("string"), mock.AnythingOfType("uint32"), mock.Anything).Return(tt.args.lock, tt.args.lockErr) cmdUtilsMock.On("WaitForAppropriateState", mock.AnythingOfType("*ethclient.Client"), mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.state, tt.args.stateErr) - utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(txnOpts) + utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(TxnOpts) stakeManagerMock.On("Unstake", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.unstakeTxn, tt.args.unstakeErr) transactionMock.On("Hash", mock.Anything).Return(tt.args.hash) utils := &UtilsStruct{} _, gotErr := utils.Unstake(config, client, types.UnstakeInput{ - Address: address, - Password: password, + Account: account, StakerId: stakerId, ValueInWei: tt.args.amount, }) @@ -297,7 +293,8 @@ func TestExecuteUnstake(t *testing.T) { fileUtilsMock.On("AssignLogFile", mock.AnythingOfType("*pflag.FlagSet"), mock.Anything) cmdUtilsMock.On("GetConfigData").Return(tt.args.config, tt.args.configErr) utilsMock.On("AssignPassword", flagSet).Return(tt.args.password) - utilsMock.On("CheckPassword", mock.Anything, mock.Anything).Return(nil) + utilsMock.On("CheckPassword", mock.Anything).Return(nil) + utilsMock.On("AccountManagerForKeystore").Return(&accounts.AccountManager{}, nil) flagSetMock.On("GetStringAddress", flagSet).Return(tt.args.address, tt.args.addressErr) utilsMock.On("ConnectToClient", mock.AnythingOfType("string")).Return(client) cmdUtilsMock.On("AssignAmountInWei", flagSet).Return(tt.args.value, tt.args.valueErr) diff --git a/cmd/updateCollection.go b/cmd/updateCollection.go index a73c92c08..10b249064 100644 --- a/cmd/updateCollection.go +++ b/cmd/updateCollection.go @@ -2,6 +2,7 @@ package cmd import ( + "razor/accounts" "razor/core" "razor/core/types" "razor/logger" @@ -52,7 +53,12 @@ func (*UtilsStruct) ExecuteUpdateCollection(flagSet *pflag.FlagSet) { log.Debug("Getting password...") password := razorUtils.AssignPassword(flagSet) - err = razorUtils.CheckPassword(address, password) + accountManager, err := razorUtils.AccountManagerForKeystore() + utils.CheckError("Error in getting accounts manager for keystore: ", err) + + account := accounts.InitAccountStruct(address, password, accountManager) + + err = razorUtils.CheckPassword(account) utils.CheckError("Error in fetching private key from given password: ", err) collectionId, err := flagSetUtils.GetUint16CollectionId(flagSet) @@ -71,14 +77,12 @@ func (*UtilsStruct) ExecuteUpdateCollection(flagSet *pflag.FlagSet) { utils.CheckError("Error in getting tolerance: ", err) collectionInput := types.CreateCollectionInput{ - Address: address, - Password: password, Aggregation: aggregation, Power: power, JobIds: jobIdInUint, Tolerance: tolerance, + Account: account, } - log.Debugf("ExecuteUpdateCollection: Calling UpdateCollection() with arguments collectionInput: %+v, collectionId: %d", collectionInput, collectionId) txn, err := cmdUtils.UpdateCollection(client, config, collectionInput, collectionId) utils.CheckError("Update Collection error: ", err) err = razorUtils.WaitForBlockCompletion(client, txn.Hex()) @@ -96,14 +100,13 @@ func (*UtilsStruct) UpdateCollection(client *ethclient.Client, config types.Conf } txnOpts := razorUtils.GetTxnOpts(types.TransactionOptions{ Client: client, - Password: collectionInput.Password, - AccountAddress: collectionInput.Address, ChainId: core.ChainId, Config: config, ContractAddress: core.CollectionManagerAddress, MethodName: "updateCollection", Parameters: []interface{}{collectionId, collectionInput.Tolerance, collectionInput.Aggregation, collectionInput.Power, jobIds}, ABI: bindings.CollectionManagerMetaData.ABI, + Account: collectionInput.Account, }) log.Info("Updating collection...") log.Debugf("Executing UpdateCollection transaction with collectionId = %d, tolerance = %d, aggregation method = %d, power = %d, jobIds = %v", collectionId, collectionInput.Tolerance, collectionInput.Aggregation, collectionInput.Power, jobIds) diff --git a/cmd/updateCollection_test.go b/cmd/updateCollection_test.go index 692656719..f473a212d 100644 --- a/cmd/updateCollection_test.go +++ b/cmd/updateCollection_test.go @@ -1,16 +1,13 @@ package cmd import ( - "crypto/ecdsa" - "crypto/rand" "errors" - "github.com/ethereum/go-ethereum/crypto" "math/big" + "razor/accounts" "razor/core" "razor/core/types" "testing" - "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" Types "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethclient" @@ -19,9 +16,6 @@ import ( ) func TestUpdateCollection(t *testing.T) { - privateKey, _ := ecdsa.GenerateKey(crypto.S256(), rand.Reader) - txnOpts, _ := bind.NewKeyedTransactorWithChainID(privateKey, big.NewInt(1)) - var client *ethclient.Client var config types.Configurations var WaitIfCommitStateStatus uint32 @@ -30,7 +24,6 @@ func TestUpdateCollection(t *testing.T) { var collectionId uint16 type args struct { - txnOpts *bind.TransactOpts updateCollectionTxn *Types.Transaction updateCollectionErr error waitIfCommitStateErr error @@ -46,7 +39,6 @@ func TestUpdateCollection(t *testing.T) { { name: "Test 1: When UpdateCollection function executes successfully", args: args{ - txnOpts: txnOpts, updateCollectionTxn: &Types.Transaction{}, updateCollectionErr: nil, waitIfCommitStateErr: nil, @@ -58,7 +50,6 @@ func TestUpdateCollection(t *testing.T) { { name: "Test 2: When updateCollection transaction fails", args: args{ - txnOpts: txnOpts, updateCollectionTxn: &Types.Transaction{}, updateCollectionErr: errors.New("updateCollection error"), waitIfCommitStateErr: nil, @@ -70,7 +61,6 @@ func TestUpdateCollection(t *testing.T) { { name: "Test 3: When there is an error in WaitIfConfirmState", args: args{ - txnOpts: txnOpts, updateCollectionTxn: &Types.Transaction{}, updateCollectionErr: nil, waitIfCommitStateErr: errors.New("waitIfCommitState error"), @@ -85,7 +75,7 @@ func TestUpdateCollection(t *testing.T) { SetUpMockInterfaces() utilsMock.On("ConvertUintArrayToUint16Array", mock.Anything).Return(jobIdUint16) - utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(txnOpts) + utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(TxnOpts) cmdUtilsMock.On("WaitIfCommitState", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("string")).Return(WaitIfCommitStateStatus, tt.args.waitIfCommitStateErr) assetManagerMock.On("UpdateCollection", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.updateCollectionTxn, tt.args.updateCollectionErr) transactionMock.On("Hash", mock.Anything).Return(tt.args.hash) @@ -287,7 +277,8 @@ func TestExecuteUpdateCollection(t *testing.T) { fileUtilsMock.On("AssignLogFile", mock.AnythingOfType("*pflag.FlagSet"), mock.Anything) cmdUtilsMock.On("GetConfigData").Return(tt.args.config, tt.args.configErr) utilsMock.On("AssignPassword", flagSet).Return(tt.args.password) - utilsMock.On("CheckPassword", mock.Anything, mock.Anything).Return(nil) + utilsMock.On("CheckPassword", mock.Anything).Return(nil) + utilsMock.On("AccountManagerForKeystore").Return(&accounts.AccountManager{}, nil) flagSetMock.On("GetStringAddress", flagSet).Return(tt.args.address, tt.args.addressErr) flagSetMock.On("GetUint16CollectionId", flagSet).Return(tt.args.collectionId, tt.args.collectionIdErr) flagSetMock.On("GetUintSliceJobIds", flagSet).Return(tt.args.jobId, tt.args.jobIdErr) diff --git a/cmd/updateCommission.go b/cmd/updateCommission.go index b5fa8e3d9..5c59f15ab 100644 --- a/cmd/updateCommission.go +++ b/cmd/updateCommission.go @@ -3,6 +3,7 @@ package cmd import ( "errors" + "razor/accounts" "razor/core" "razor/core/types" "razor/logger" @@ -49,7 +50,12 @@ func (*UtilsStruct) ExecuteUpdateCommission(flagSet *pflag.FlagSet) { log.Debug("Getting password...") password := razorUtils.AssignPassword(flagSet) - err = razorUtils.CheckPassword(address, password) + accountManager, err := razorUtils.AccountManagerForKeystore() + utils.CheckError("Error in getting accounts manager for keystore: ", err) + + account := accounts.InitAccountStruct(address, password, accountManager) + + err = razorUtils.CheckPassword(account) utils.CheckError("Error in fetching private key from given password: ", err) commission, err := flagSetUtils.GetUint8Commission(flagSet) @@ -60,14 +66,12 @@ func (*UtilsStruct) ExecuteUpdateCommission(flagSet *pflag.FlagSet) { updateCommissionInput := types.UpdateCommissionInput{ Commission: commission, - Address: address, - Password: password, StakerId: stakerId, + Account: account, } - log.Debugf("ExecuteUpdateCommission: calling UpdateCommission() with argument UpdateCommissionInput: %+v", updateCommissionInput) err = cmdUtils.UpdateCommission(config, client, updateCommissionInput) - utils.CheckError("SetDelegation error: ", err) + utils.CheckError("UpdateCommission error: ", err) } //This function allows a staker to add/update the commission value @@ -114,14 +118,13 @@ func (*UtilsStruct) UpdateCommission(config types.Configurations, client *ethcli } txnOpts := types.TransactionOptions{ Client: client, - Password: updateCommissionInput.Password, - AccountAddress: updateCommissionInput.Address, ChainId: core.ChainId, Config: config, ContractAddress: core.StakeManagerAddress, ABI: bindings.StakeManagerMetaData.ABI, MethodName: "updateCommission", Parameters: []interface{}{updateCommissionInput.Commission}, + Account: updateCommissionInput.Account, } updateCommissionTxnOpts := razorUtils.GetTxnOpts(txnOpts) log.Infof("Setting the commission value of Staker %d to %d%%", updateCommissionInput.StakerId, updateCommissionInput.Commission) diff --git a/cmd/updateCommission_test.go b/cmd/updateCommission_test.go index 725670a26..79dbeecd7 100644 --- a/cmd/updateCommission_test.go +++ b/cmd/updateCommission_test.go @@ -1,16 +1,12 @@ package cmd import ( - "crypto/ecdsa" - "crypto/rand" "errors" - "github.com/ethereum/go-ethereum/crypto" - "math/big" + "razor/accounts" "razor/core/types" "razor/pkg/bindings" "testing" - "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" Types "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethclient" @@ -27,9 +23,6 @@ func TestUpdateCommission(t *testing.T) { WaitTime: 1, } - privateKey, _ := ecdsa.GenerateKey(crypto.S256(), rand.Reader) - txnOpts, _ := bind.NewKeyedTransactorWithChainID(privateKey, big.NewInt(1)) - type args struct { commission uint8 stakerInfo bindings.StructsStaker @@ -198,7 +191,7 @@ func TestUpdateCommission(t *testing.T) { SetUpMockInterfaces() utilsMock.On("GetStaker", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("uint32")).Return(tt.args.stakerInfo, tt.args.stakerInfoErr) - utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(txnOpts) + utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(TxnOpts) utilsMock.On("GetEpoch", mock.AnythingOfType("*ethclient.Client")).Return(tt.args.epoch, tt.args.epochErr) transactionMock.On("Hash", mock.Anything).Return(tt.args.hash) utilsMock.On("WaitForBlockCompletion", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("string")).Return(nil) @@ -342,7 +335,8 @@ func TestExecuteUpdateCommission(t *testing.T) { cmdUtilsMock.On("GetConfigData").Return(tt.args.config, tt.args.configErr) flagSetMock.On("GetStringAddress", flagSet).Return(tt.args.address, tt.args.addressErr) utilsMock.On("AssignPassword", flagSet).Return(tt.args.password) - utilsMock.On("CheckPassword", mock.Anything, mock.Anything).Return(nil) + utilsMock.On("CheckPassword", mock.Anything).Return(nil) + utilsMock.On("AccountManagerForKeystore").Return(&accounts.AccountManager{}, nil) flagSetMock.On("GetUint8Commission", flagSet).Return(tt.args.commission, tt.args.commissionErr) utilsMock.On("GetStakerId", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("string")).Return(tt.args.stakerId, tt.args.stakerIdErr) utilsMock.On("ConnectToClient", mock.AnythingOfType("string")).Return(client) diff --git a/cmd/updateJob.go b/cmd/updateJob.go index e75e421f2..a1cbe352d 100644 --- a/cmd/updateJob.go +++ b/cmd/updateJob.go @@ -2,6 +2,7 @@ package cmd import ( + "razor/accounts" "razor/core" "razor/core/types" "razor/logger" @@ -52,7 +53,12 @@ func (*UtilsStruct) ExecuteUpdateJob(flagSet *pflag.FlagSet) { log.Debug("Getting password...") password := razorUtils.AssignPassword(flagSet) - err = razorUtils.CheckPassword(address, password) + accountManager, err := razorUtils.AccountManagerForKeystore() + utils.CheckError("Error in getting accounts manager for keystore: ", err) + + account := accounts.InitAccountStruct(address, password, accountManager) + + err = razorUtils.CheckPassword(account) utils.CheckError("Error in fetching private key from given password: ", err) jobId, err := flagSetUtils.GetUint16JobId(flagSet) @@ -74,16 +80,14 @@ func (*UtilsStruct) ExecuteUpdateJob(flagSet *pflag.FlagSet) { utils.CheckError("Error in getting selector type: ", err) jobInput := types.CreateJobInput{ - Address: address, - Password: password, Power: power, Selector: selector, Url: url, Weight: weight, SelectorType: selectorType, + Account: account, } - log.Debugf("ExecuteUpdateJob: Calling UpdateJob() with arguments jobInput = %+v, jobId = %d", jobInput, jobId) txn, err := cmdUtils.UpdateJob(client, config, jobInput, jobId) utils.CheckError("UpdateJob error: ", err) err = razorUtils.WaitForBlockCompletion(client, txn.Hex()) @@ -99,14 +103,13 @@ func (*UtilsStruct) UpdateJob(client *ethclient.Client, config types.Configurati } txnArgs := razorUtils.GetTxnOpts(types.TransactionOptions{ Client: client, - Password: jobInput.Password, - AccountAddress: jobInput.Address, ChainId: core.ChainId, Config: config, ContractAddress: core.CollectionManagerAddress, MethodName: "updateJob", Parameters: []interface{}{jobId, jobInput.Weight, jobInput.Power, jobInput.SelectorType, jobInput.Selector, jobInput.Url}, ABI: bindings.CollectionManagerMetaData.ABI, + Account: jobInput.Account, }) log.Info("Updating Job...") log.Debugf("Executing UpdateJob transaction with arguments jobId = %d, weight = %d, power = %d, selector type = %d, selector = %s, URL = %s", jobId, jobInput.Weight, jobInput.Power, jobInput.SelectorType, jobInput.Selector, jobInput.Url) diff --git a/cmd/updateJob_test.go b/cmd/updateJob_test.go index 77f49f905..0d2da59a7 100644 --- a/cmd/updateJob_test.go +++ b/cmd/updateJob_test.go @@ -1,16 +1,13 @@ package cmd import ( - "crypto/ecdsa" - "crypto/rand" "errors" - "github.com/ethereum/go-ethereum/crypto" "math/big" + "razor/accounts" "razor/core" "razor/core/types" "testing" - "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" Types "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethclient" @@ -26,11 +23,7 @@ func TestUpdateJob(t *testing.T) { var jobInput types.CreateJobInput var jobId uint16 - privateKey, _ := ecdsa.GenerateKey(crypto.S256(), rand.Reader) - txnOpts, _ := bind.NewKeyedTransactorWithChainID(privateKey, big.NewInt(1)) - type args struct { - txnOpts *bind.TransactOpts updateJobTxn *Types.Transaction updateJobErr error waitIfCommitStateErr error @@ -45,7 +38,6 @@ func TestUpdateJob(t *testing.T) { { name: "Test 1: When UpdateJob function executes successfully", args: args{ - txnOpts: txnOpts, updateJobTxn: &Types.Transaction{}, hash: common.BigToHash(big.NewInt(1)), }, @@ -55,7 +47,6 @@ func TestUpdateJob(t *testing.T) { { name: "Test 2: When updateJob transaction fails", args: args{ - txnOpts: txnOpts, updateJobTxn: &Types.Transaction{}, updateJobErr: errors.New("updateJob error"), hash: common.BigToHash(big.NewInt(1)), @@ -66,7 +57,6 @@ func TestUpdateJob(t *testing.T) { { name: "Test 3: When there is an error in WaitIfConfirmState", args: args{ - txnOpts: txnOpts, updateJobTxn: &Types.Transaction{}, waitIfCommitStateErr: errors.New("waitIfCommitState error"), hash: common.BigToHash(big.NewInt(1)), @@ -79,7 +69,7 @@ func TestUpdateJob(t *testing.T) { t.Run(tt.name, func(t *testing.T) { SetUpMockInterfaces() - utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(txnOpts) + utilsMock.On("GetTxnOpts", mock.AnythingOfType("types.TransactionOptions")).Return(TxnOpts) cmdUtilsMock.On("WaitIfCommitState", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("string")).Return(WaitIfCommitStateStatus, tt.args.waitIfCommitStateErr) assetManagerMock.On("UpdateJob", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("*bind.TransactOpts"), mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.updateJobTxn, tt.args.updateJobErr) transactionMock.On("Hash", mock.Anything).Return(tt.args.hash) @@ -334,7 +324,8 @@ func TestExecuteUpdateJob(t *testing.T) { fileUtilsMock.On("AssignLogFile", mock.AnythingOfType("*pflag.FlagSet"), mock.Anything) cmdUtilsMock.On("GetConfigData").Return(tt.args.config, tt.args.configErr) utilsMock.On("AssignPassword", flagSet).Return(tt.args.password) - utilsMock.On("CheckPassword", mock.Anything, mock.Anything).Return(nil) + utilsMock.On("CheckPassword", mock.Anything).Return(nil) + utilsMock.On("AccountManagerForKeystore").Return(&accounts.AccountManager{}, nil) flagSetMock.On("GetStringAddress", flagSet).Return(tt.args.address, tt.args.addressErr) flagSetMock.On("GetStringUrl", flagSet).Return(tt.args.url, tt.args.urlErr) flagSetMock.On("GetStringSelector", flagSet).Return(tt.args.selector, tt.args.selectorErr) diff --git a/cmd/vote.go b/cmd/vote.go index 10410b4b2..4aa85cdaa 100644 --- a/cmd/vote.go +++ b/cmd/vote.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "math/big" + "net/http" "os" "os/signal" "path/filepath" @@ -17,6 +18,8 @@ import ( "razor/utils" "time" + Types "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/crypto" "github.com/spf13/pflag" @@ -62,7 +65,12 @@ func (*UtilsStruct) ExecuteVote(flagSet *pflag.FlagSet) { log.Debug("Getting password...") password := razorUtils.AssignPassword(flagSet) - err = razorUtils.CheckPassword(address, password) + accountManager, err := razorUtils.AccountManagerForKeystore() + utils.CheckError("Error in getting accounts manager for keystore: ", err) + + account := accounts.InitAccountStruct(address, password, accountManager) + + err = razorUtils.CheckPassword(account) utils.CheckError("Error in fetching private key from given password: ", err) isRogue, err := flagSetUtils.GetBoolRogue(flagSet) @@ -86,11 +94,35 @@ func (*UtilsStruct) ExecuteVote(flagSet *pflag.FlagSet) { log.Warn("YOU ARE RUNNING VOTE IN ROGUE MODE, THIS CAN INCUR PENALTIES!") } - account := types.Account{Address: address, Password: password} + httpClient := &http.Client{ + Timeout: time.Duration(config.HTTPTimeout) * time.Second, + Transport: &http.Transport{ + MaxIdleConns: core.HTTPClientMaxIdleConns, + MaxIdleConnsPerHost: core.HTTPClientMaxIdleConnsPerHost, + }, + } + + stakerId, err := razorUtils.GetStakerId(client, address) + utils.CheckError("Error in getting staker id: ", err) + + if stakerId == 0 { + log.Fatal("Staker doesn't exist") + } cmdUtils.HandleExit() + + jobsCache, collectionsCache, initCacheBlockNumber, err := cmdUtils.InitJobAndCollectionCache(client) + utils.CheckError("Error in initializing asset cache: ", err) + + commitParams := &types.CommitParams{ + JobsCache: jobsCache, + CollectionsCache: collectionsCache, + HttpClient: httpClient, + FromBlockToCheckForEvents: initCacheBlockNumber, + } + log.Debugf("Calling Vote() with arguments rogueData = %+v, account address = %s, backup node actions to ignore = %s", rogueData, account.Address, backupNodeActionsToIgnore) - if err := cmdUtils.Vote(context.Background(), config, client, rogueData, account, backupNodeActionsToIgnore); err != nil { + if err := cmdUtils.Vote(context.Background(), config, client, account, stakerId, commitParams, rogueData, backupNodeActionsToIgnore); err != nil { log.Errorf("%v\n", err) osUtils.Exit(1) } @@ -121,7 +153,7 @@ func (*UtilsStruct) HandleExit() { } //This function handles all the states of voting -func (*UtilsStruct) Vote(ctx context.Context, config types.Configurations, client *ethclient.Client, rogueData types.Rogue, account types.Account, backupNodeActionsToIgnore []string) error { +func (*UtilsStruct) Vote(ctx context.Context, config types.Configurations, client *ethclient.Client, account types.Account, stakerId uint32, commitParams *types.CommitParams, rogueData types.Rogue, backupNodeActionsToIgnore []string) error { header, err := clientUtils.GetLatestBlockWithRetry(client) utils.CheckError("Error in getting block: ", err) for { @@ -138,7 +170,7 @@ func (*UtilsStruct) Vote(ctx context.Context, config types.Configurations, clien log.Debugf("Vote: Latest header value: %d", latestHeader.Number) if latestHeader.Number.Cmp(header.Number) != 0 { header = latestHeader - cmdUtils.HandleBlock(client, account, latestHeader.Number, config, rogueData, backupNodeActionsToIgnore) + cmdUtils.HandleBlock(client, account, stakerId, latestHeader, config, commitParams, rogueData, backupNodeActionsToIgnore) } time.Sleep(time.Second * time.Duration(core.BlockNumberInterval)) } @@ -153,8 +185,8 @@ var ( ) //This function handles the block -func (*UtilsStruct) HandleBlock(client *ethclient.Client, account types.Account, blockNumber *big.Int, config types.Configurations, rogueData types.Rogue, backupNodeActionsToIgnore []string) { - state, err := razorUtils.GetBufferedState(client, config.BufferPercent) +func (*UtilsStruct) HandleBlock(client *ethclient.Client, account types.Account, stakerId uint32, latestHeader *Types.Header, config types.Configurations, commitParams *types.CommitParams, rogueData types.Rogue, backupNodeActionsToIgnore []string) { + state, err := razorUtils.GetBufferedState(client, latestHeader, config.BufferPercent) if err != nil { log.Error("Error in getting state: ", err) return @@ -165,15 +197,6 @@ func (*UtilsStruct) HandleBlock(client *ethclient.Client, account types.Account, return } - stakerId, err := razorUtils.GetStakerId(client, account.Address) - if err != nil { - log.Error("Error in getting staker id: ", err) - return - } - if stakerId == 0 { - log.Error("Staker doesn't exist") - return - } staker, err := razorUtils.GetStaker(client, stakerId) if err != nil { log.Error(err) @@ -230,21 +253,21 @@ func (*UtilsStruct) HandleBlock(client *ethclient.Client, account types.Account, switch state { case 0: log.Debugf("Starting commit...") - err := cmdUtils.InitiateCommit(client, config, account, epoch, stakerId, rogueData) + err := cmdUtils.InitiateCommit(client, config, account, epoch, stakerId, latestHeader, commitParams, rogueData) if err != nil { log.Error(err) break } case 1: log.Debugf("Starting reveal...") - err := cmdUtils.InitiateReveal(client, config, account, epoch, staker, rogueData) + err := cmdUtils.InitiateReveal(client, config, account, epoch, staker, latestHeader, rogueData) if err != nil { log.Error(err) break } case 2: log.Debugf("Starting propose...") - err := cmdUtils.InitiatePropose(client, config, account, epoch, staker, blockNumber, rogueData) + err := cmdUtils.InitiatePropose(client, config, account, epoch, staker, latestHeader, rogueData) if err != nil { log.Error(err) break @@ -257,7 +280,7 @@ func (*UtilsStruct) HandleBlock(client *ethclient.Client, account types.Account, break } - err := cmdUtils.HandleDispute(client, config, account, epoch, blockNumber, rogueData, backupNodeActionsToIgnore) + err := cmdUtils.HandleDispute(client, config, account, epoch, latestHeader.Number, rogueData, backupNodeActionsToIgnore) if err != nil { log.Error(err) break @@ -280,13 +303,12 @@ func (*UtilsStruct) HandleBlock(client *ethclient.Client, account types.Account, if lastVerification == epoch && blockConfirmed < epoch { txn, err := cmdUtils.ClaimBlockReward(types.TransactionOptions{ Client: client, - Password: account.Password, - AccountAddress: account.Address, ChainId: core.ChainId, Config: config, ContractAddress: core.BlockManagerAddress, MethodName: "claimBlockReward", ABI: bindings.BlockManagerMetaData.ABI, + Account: account, }) if err != nil { @@ -303,8 +325,8 @@ func (*UtilsStruct) HandleBlock(client *ethclient.Client, account types.Account, } } case -1: - if config.WaitTime > 5 { - timeUtils.Sleep(5 * time.Second) + if config.WaitTime >= core.BufferStateSleepTime { + timeUtils.Sleep(time.Second * time.Duration(core.BufferStateSleepTime)) return } } @@ -313,7 +335,24 @@ func (*UtilsStruct) HandleBlock(client *ethclient.Client, account types.Account, } //This function initiates the commit -func (*UtilsStruct) InitiateCommit(client *ethclient.Client, config types.Configurations, account types.Account, epoch uint32, stakerId uint32, rogueData types.Rogue) error { +func (*UtilsStruct) InitiateCommit(client *ethclient.Client, config types.Configurations, account types.Account, epoch uint32, stakerId uint32, latestHeader *Types.Header, commitParams *types.CommitParams, rogueData types.Rogue) error { + lastCommit, err := razorUtils.GetEpochLastCommitted(client, stakerId) + if err != nil { + return errors.New("Error in fetching last commit: " + err.Error()) + } + log.Debug("InitiateCommit: Epoch last committed: ", lastCommit) + + if lastCommit >= epoch { + log.Debugf("Cannot commit in epoch %d because last committed epoch is %d", epoch, lastCommit) + return nil + } + + err = CheckForJobAndCollectionEvents(client, commitParams) + if err != nil { + log.Error("Error in checking for asset events: ", err) + return err + } + staker, err := razorUtils.GetStaker(client, stakerId) if err != nil { log.Error(err) @@ -331,17 +370,6 @@ func (*UtilsStruct) InitiateCommit(client *ethclient.Client, config types.Config log.Error("Stake is below minimum required. Kindly add stake to continue voting.") return nil } - - lastCommit, err := razorUtils.GetEpochLastCommitted(client, stakerId) - if err != nil { - return errors.New("Error in fetching last commit: " + err.Error()) - } - log.Debug("InitiateCommit: Epoch last committed: ", lastCommit) - - if lastCommit >= epoch { - log.Debugf("Cannot commit in epoch %d because last committed epoch is %d", epoch, lastCommit) - return nil - } razorPath, err := pathUtils.GetDefaultPath() if err != nil { return err @@ -356,13 +384,13 @@ func (*UtilsStruct) InitiateCommit(client *ethclient.Client, config types.Config } log.Debugf("InitiateCommit: Calling HandleCommitState with arguments epoch = %d, seed = %v, rogueData = %+v", epoch, seed, rogueData) - commitData, err := cmdUtils.HandleCommitState(client, epoch, seed, rogueData) + commitData, err := cmdUtils.HandleCommitState(client, epoch, seed, commitParams, rogueData) if err != nil { return errors.New("Error in getting active assets: " + err.Error()) } log.Debug("InitiateCommit: Commit Data: ", commitData) - commitTxn, err := cmdUtils.Commit(client, config, account, epoch, seed, commitData.Leaves) + commitTxn, err := cmdUtils.Commit(client, config, account, epoch, latestHeader, seed, commitData.Leaves) if err != nil { return errors.New("Error in committing data: " + err.Error()) } @@ -395,7 +423,7 @@ func (*UtilsStruct) InitiateCommit(client *ethclient.Client, config types.Config } //This function initiates the reveal -func (*UtilsStruct) InitiateReveal(client *ethclient.Client, config types.Configurations, account types.Account, epoch uint32, staker bindings.StructsStaker, rogueData types.Rogue) error { +func (*UtilsStruct) InitiateReveal(client *ethclient.Client, config types.Configurations, account types.Account, epoch uint32, staker bindings.StructsStaker, latestHeader *Types.Header, rogueData types.Rogue) error { stakedAmount := staker.Stake log.Debug("InitiateReveal: Staked Amount: ", stakedAmount) minStakeAmount, err := razorUtils.GetMinStakeAmount(client) @@ -506,14 +534,14 @@ func (*UtilsStruct) InitiateReveal(client *ethclient.Client, config types.Config SeqAllottedCollections: globalCommitDataStruct.SeqAllottedCollections, } log.Debugf("InitiateReveal: Calling Reveal() with arguments epoch = %d, commitDataToSend = %+v, signature = %v", epoch, commitDataToSend, signature) - revealTxn, err := cmdUtils.Reveal(client, config, account, epoch, commitDataToSend, signature) + revealTxn, err := cmdUtils.Reveal(client, config, account, epoch, latestHeader, commitDataToSend, signature) if err != nil { return errors.New("Reveal error: " + err.Error()) } if revealTxn != core.NilHash { waitForBlockCompletionErr := razorUtils.WaitForBlockCompletion(client, revealTxn.Hex()) if waitForBlockCompletionErr != nil { - log.Error("Error in WaitForBlockCompletionErr for reveal: ", err) + log.Error("Error in WaitForBlockCompletionErr for reveal: ", waitForBlockCompletionErr) return err } } @@ -525,7 +553,7 @@ func (*UtilsStruct) InitiateReveal(client *ethclient.Client, config types.Config } //This function initiates the propose -func (*UtilsStruct) InitiatePropose(client *ethclient.Client, config types.Configurations, account types.Account, epoch uint32, staker bindings.StructsStaker, blockNumber *big.Int, rogueData types.Rogue) error { +func (*UtilsStruct) InitiatePropose(client *ethclient.Client, config types.Configurations, account types.Account, epoch uint32, staker bindings.StructsStaker, latestHeader *Types.Header, rogueData types.Rogue) error { stakedAmount := staker.Stake log.Debug("InitiatePropose: Staked Amount: ", stakedAmount) minStakeAmount, err := razorUtils.GetMinStakeAmount(client) @@ -557,8 +585,8 @@ func (*UtilsStruct) InitiatePropose(client *ethclient.Client, config types.Confi return nil } - log.Debugf("InitiatePropose: Calling Propose() with arguments staker = %+v, epoch = %d, blockNumber = %s, rogueData = %+v", staker, epoch, blockNumber, rogueData) - err = cmdUtils.Propose(client, config, account, staker, epoch, blockNumber, rogueData) + log.Debugf("InitiatePropose: Calling Propose() with arguments staker = %+v, epoch = %d, blockNumber = %s, rogueData = %+v", staker, epoch, latestHeader.Number, rogueData) + err = cmdUtils.Propose(client, config, account, staker, epoch, latestHeader, rogueData) if err != nil { return errors.New("Propose error: " + err.Error()) } @@ -575,7 +603,7 @@ func (*UtilsStruct) CalculateSecret(account types.Account, epoch uint32, keystor ethHash := utils.SignHash(hash) log.Debug("Hash generated for secret") log.Debug("CalculateSecret: Ethereum signed hash: ", ethHash) - signedData, err := accounts.AccountUtilsInterface.SignData(ethHash, account, keystorePath) + signedData, err := account.AccountManager.SignData(ethHash, account.Address, account.Password) if err != nil { return nil, nil, errors.New("Error in signing the data: " + err.Error()) } diff --git a/cmd/vote_test.go b/cmd/vote_test.go index 4dadc4683..26e8d2186 100644 --- a/cmd/vote_test.go +++ b/cmd/vote_test.go @@ -1,17 +1,23 @@ package cmd import ( + "context" "encoding/hex" "errors" "math/big" "os" "path" "path/filepath" + "razor/accounts" + "razor/cache" "razor/core/types" "razor/pkg/bindings" "razor/utils" "reflect" "testing" + "time" + + Types "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/ethclient" @@ -34,6 +40,8 @@ func TestExecuteVote(t *testing.T) { rogueModeErr error address string addressErr error + stakerId uint32 + stakerIdErr error voteErr error } tests := []struct { @@ -47,6 +55,7 @@ func TestExecuteVote(t *testing.T) { config: config, password: "test", address: "0x000000000000000000000000000000000000dea1", + stakerId: 1, rogueStatus: true, rogueMode: []string{"propose", "commit"}, voteErr: nil, @@ -109,12 +118,37 @@ func TestExecuteVote(t *testing.T) { config: config, password: "test", address: "0x000000000000000000000000000000000000dea1", + stakerId: 1, rogueStatus: true, rogueMode: []string{"propose", "commit"}, voteErr: errors.New("vote error"), }, expectedFatal: false, }, + { + name: "Test 7: When there is an error in getting stakerId", + args: args{ + config: config, + password: "test", + address: "0x000000000000000000000000000000000000dea1", + stakerIdErr: errors.New("stakerId error"), + rogueStatus: true, + rogueMode: []string{"propose", "commit"}, + }, + expectedFatal: true, + }, + { + name: "Test 8: When stakerId is 0", + args: args{ + config: config, + password: "test", + address: "0x000000000000000000000000000000000000dea1", + stakerId: 0, + rogueStatus: true, + rogueMode: []string{"propose", "commit"}, + }, + expectedFatal: true, + }, } defer func() { log.ExitFunc = nil }() @@ -128,14 +162,17 @@ func TestExecuteVote(t *testing.T) { fileUtilsMock.On("AssignLogFile", mock.AnythingOfType("*pflag.FlagSet"), mock.Anything) cmdUtilsMock.On("GetConfigData").Return(tt.args.config, tt.args.configErr) utilsMock.On("AssignPassword", flagSet).Return(tt.args.password) - utilsMock.On("CheckPassword", mock.Anything, mock.Anything).Return(nil) + utilsMock.On("CheckPassword", mock.Anything).Return(nil) + utilsMock.On("AccountManagerForKeystore").Return(&accounts.AccountManager{}, nil) flagSetMock.On("GetStringAddress", mock.AnythingOfType("*pflag.FlagSet")).Return(tt.args.address, tt.args.addressErr) flagSetMock.On("GetStringSliceBackupNode", mock.Anything).Return([]string{}, nil) utilsMock.On("ConnectToClient", mock.AnythingOfType("string")).Return(client) flagSetMock.On("GetBoolRogue", mock.AnythingOfType("*pflag.FlagSet")).Return(tt.args.rogueStatus, tt.args.rogueErr) + utilsMock.On("GetStakerId", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("string")).Return(tt.args.stakerId, tt.args.stakerIdErr) flagSetMock.On("GetStringSliceRogueMode", mock.AnythingOfType("*pflag.FlagSet")).Return(tt.args.rogueMode, tt.args.rogueModeErr) + cmdUtilsMock.On("InitJobAndCollectionCache", mock.Anything).Return(&cache.JobsCache{}, &cache.CollectionsCache{}, big.NewInt(100), nil) cmdUtilsMock.On("HandleExit").Return() - cmdUtilsMock.On("Vote", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.voteErr) + cmdUtilsMock.On("Vote", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.voteErr) osMock.On("Exit", mock.AnythingOfType("int")).Return() utils := &UtilsStruct{} @@ -275,33 +312,43 @@ func TestCalculateSecret(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { InitializeInterfaces() - gotSignature, gotSecret, err := cmdUtils.CalculateSecret(types.Account{Address: tt.args.address, - Password: tt.args.password}, tt.args.epoch, testKeystorePath, tt.args.chainId) + accountManager := accounts.NewAccountManager(testKeystorePath) + account := accounts.InitAccountStruct(tt.args.address, tt.args.password, accountManager) + gotSignature, gotSecret, err := cmdUtils.CalculateSecret(account, tt.args.epoch, testKeystorePath, tt.args.chainId) + if (err != nil) != tt.wantErr { + t.Errorf("CalculateSecret() error = %v, wantErr %v", err, tt.wantErr) + return + } gotSignatureInHash := hex.EncodeToString(gotSignature) gotSecretInHash := hex.EncodeToString(gotSecret) + if !reflect.DeepEqual(gotSignatureInHash, tt.wantSignature) { t.Errorf("CalculateSecret() Signature = %v, want %v", gotSignatureInHash, tt.wantSignature) } if !reflect.DeepEqual(gotSecretInHash, tt.wantSecret) { t.Errorf("CalculateSecret() Secret = %v, want %v", gotSecretInHash, tt.wantSecret) } - if (err != nil) != tt.wantErr { - t.Errorf("CalculateSecret() error = %v, wantErr %v", err, tt.wantErr) - return - } }) } } func TestInitiateCommit(t *testing.T) { var ( - client *ethclient.Client - config types.Configurations - account types.Account - stakerId uint32 - rogueData types.Rogue + client *ethclient.Client + config types.Configurations + latestHeader *Types.Header + account types.Account + stakerId uint32 + rogueData types.Rogue ) + + commitParams := &types.CommitParams{ + JobsCache: cache.NewJobsCache(), + CollectionsCache: cache.NewCollectionsCache(), + FromBlockToCheckForEvents: big.NewInt(1), + } + type args struct { staker bindings.StructsStaker stakerErr error @@ -354,13 +401,17 @@ func TestInitiateCommit(t *testing.T) { { name: "Test 2: When there is an error in getting staker", args: args{ - stakerErr: errors.New("error in getting staker"), + epoch: 5, + lastCommit: 2, + stakerErr: errors.New("error in getting staker"), }, wantErr: true, }, { name: "Test 3: When there is an error in getting minStakeAmount", args: args{ + epoch: 5, + lastCommit: 2, staker: bindings.StructsStaker{Id: 1, Stake: big.NewInt(10000)}, minStakeAmountErr: errors.New("error in getting minStakeAmount"), }, @@ -501,14 +552,16 @@ func TestInitiateCommit(t *testing.T) { utilsMock.On("GetEpochLastCommitted", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("uint32")).Return(tt.args.lastCommit, tt.args.lastCommitErr) cmdUtilsMock.On("CalculateSecret", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.signature, tt.args.secret, tt.args.secretErr) cmdUtilsMock.On("GetSalt", mock.AnythingOfType("*ethclient.Client"), mock.Anything).Return(tt.args.salt, tt.args.saltErr) - cmdUtilsMock.On("HandleCommitState", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.commitData, tt.args.commitDataErr) + cmdUtilsMock.On("HandleCommitState", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.commitData, tt.args.commitDataErr) pathMock.On("GetDefaultPath").Return(tt.args.path, tt.args.pathErr) - cmdUtilsMock.On("Commit", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.commitTxn, tt.args.commitTxnErr) + cmdUtilsMock.On("Commit", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.commitTxn, tt.args.commitTxnErr) utilsMock.On("WaitForBlockCompletion", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("string")).Return(tt.args.waitForBlockCompletionErr) pathMock.On("GetCommitDataFileName", mock.AnythingOfType("string")).Return(tt.args.fileName, tt.args.fileNameErr) fileUtilsMock.On("SaveDataToCommitJsonFile", mock.Anything, mock.Anything, mock.Anything).Return(tt.args.saveErr) + clientUtilsMock.On("GetLatestBlockWithRetry", mock.Anything).Return(&Types.Header{Number: big.NewInt(100)}, nil) + clientUtilsMock.On("FilterLogsWithRetry", mock.Anything, mock.Anything).Return([]Types.Log{}, nil) ut := &UtilsStruct{} - if err := ut.InitiateCommit(client, config, account, tt.args.epoch, stakerId, rogueData); (err != nil) != tt.wantErr { + if err := ut.InitiateCommit(client, config, account, tt.args.epoch, stakerId, latestHeader, commitParams, rogueData); (err != nil) != tt.wantErr { t.Errorf("InitiateCommit() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -517,9 +570,10 @@ func TestInitiateCommit(t *testing.T) { func TestInitiateReveal(t *testing.T) { var ( - client *ethclient.Client - config types.Configurations - account types.Account + client *ethclient.Client + config types.Configurations + account types.Account + latestHeader *Types.Header ) randomNum := big.NewInt(1111) @@ -725,10 +779,10 @@ func TestInitiateReveal(t *testing.T) { cmdUtilsMock.On("CalculateSecret", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.signature, tt.args.secret, tt.args.secretErr) cmdUtilsMock.On("GetSalt", mock.Anything, mock.Anything).Return([32]byte{}, nil) utilsMock.On("GetCommitment", mock.Anything, mock.Anything).Return(types.Commitment{}, nil) - cmdUtilsMock.On("Reveal", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.revealTxn, tt.args.revealTxnErr) + cmdUtilsMock.On("Reveal", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.revealTxn, tt.args.revealTxnErr) utilsMock.On("WaitForBlockCompletion", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("string")).Return(nil) ut := &UtilsStruct{} - if err := ut.InitiateReveal(client, config, account, tt.args.epoch, tt.args.staker, tt.args.rogueData); (err != nil) != tt.wantErr { + if err := ut.InitiateReveal(client, config, account, tt.args.epoch, tt.args.staker, latestHeader, tt.args.rogueData); (err != nil) != tt.wantErr { t.Errorf("InitiateReveal() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -737,11 +791,10 @@ func TestInitiateReveal(t *testing.T) { func TestInitiatePropose(t *testing.T) { var ( - client *ethclient.Client - config types.Configurations - account types.Account - blockNumber *big.Int - rogueData types.Rogue + client *ethclient.Client + config types.Configurations + account types.Account + rogueData types.Rogue ) type args struct { staker bindings.StructsStaker @@ -754,6 +807,10 @@ func TestInitiatePropose(t *testing.T) { lastRevealErr error proposeTxnErr error } + + latestHeader := &Types.Header{ + Number: big.NewInt(1), + } tests := []struct { name string args args @@ -841,10 +898,10 @@ func TestInitiatePropose(t *testing.T) { utilsMock.On("GetMinStakeAmount", mock.AnythingOfType("*ethclient.Client")).Return(tt.args.minStakeAmount, tt.args.minStakeAmountErr) utilsMock.On("GetEpochLastProposed", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("uint32")).Return(tt.args.lastProposal, tt.args.lastProposalErr) utilsMock.On("GetEpochLastRevealed", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("uint32")).Return(tt.args.lastReveal, tt.args.lastRevealErr) - cmdUtilsMock.On("Propose", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.proposeTxnErr) + cmdUtilsMock.On("Propose", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.proposeTxnErr) utilsMock.On("WaitForBlockCompletion", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("string")).Return(nil) ut := &UtilsStruct{} - if err := ut.InitiatePropose(client, config, account, tt.args.epoch, tt.args.staker, blockNumber, rogueData); (err != nil) != tt.wantErr { + if err := ut.InitiatePropose(client, config, account, tt.args.epoch, tt.args.staker, latestHeader, rogueData); (err != nil) != tt.wantErr { t.Errorf("InitiatePropose() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -855,11 +912,15 @@ func TestHandleBlock(t *testing.T) { var ( client *ethclient.Client account types.Account - blockNumber *big.Int + stakerId uint32 + commitParams *types.CommitParams rogueData types.Rogue backupNodeActionsToIgnore []string ) + latestHeader := &Types.Header{ + Number: big.NewInt(1001), + } type args struct { config types.Configurations state int64 @@ -867,8 +928,6 @@ func TestHandleBlock(t *testing.T) { epoch uint32 epochErr error stateName string - stakerId uint32 - stakerIdErr error staker bindings.StructsStaker stakerErr error ethBalance *big.Int @@ -899,7 +958,6 @@ func TestHandleBlock(t *testing.T) { state: 0, epoch: 1, stateName: "commit", - stakerId: 1, staker: bindings.StructsStaker{Id: 1, Stake: big.NewInt(10000)}, ethBalance: big.NewInt(1000), actualStake: big.NewFloat(10000), @@ -921,30 +979,12 @@ func TestHandleBlock(t *testing.T) { epochErr: errors.New("error in getting epoch"), }, }, - { - name: "Test 4: When there is an error in getting stakerId", - args: args{ - state: 0, - epoch: 1, - stakerIdErr: errors.New("error in getting stakerId"), - }, - }, - { - name: "Test 5: When stakerId is 0", - args: args{ - state: 0, - epoch: 1, - stateName: "commit", - stakerId: 0, - }, - }, { name: "Test 6: When there is an error in getting staker", args: args{ state: 0, epoch: 1, stateName: "commit", - stakerId: 1, stakerErr: errors.New("error in getting staker"), }, }, @@ -954,7 +994,6 @@ func TestHandleBlock(t *testing.T) { state: 0, epoch: 1, stateName: "commit", - stakerId: 1, staker: bindings.StructsStaker{Id: 1, Stake: big.NewInt(10000)}, ethBalanceErr: errors.New("error in getting ethBalance"), }, @@ -965,7 +1004,6 @@ func TestHandleBlock(t *testing.T) { state: 0, epoch: 1, stateName: "commit", - stakerId: 1, sRZRBalance: big.NewInt(1000), staker: bindings.StructsStaker{Id: 1, Stake: big.NewInt(10000)}, ethBalance: big.NewInt(1000), @@ -979,7 +1017,6 @@ func TestHandleBlock(t *testing.T) { state: 0, epoch: 1, stateName: "commit", - stakerId: 1, staker: bindings.StructsStaker{Id: 1, Stake: big.NewInt(10000)}, ethBalance: big.NewInt(1000), actualStake: big.NewFloat(10000), @@ -993,7 +1030,6 @@ func TestHandleBlock(t *testing.T) { state: 0, epoch: 1, stateName: "commit", - stakerId: 1, staker: bindings.StructsStaker{Id: 1, Stake: big.NewInt(100)}, ethBalance: big.NewInt(1000), actualStake: big.NewFloat(100), @@ -1008,7 +1044,6 @@ func TestHandleBlock(t *testing.T) { state: 0, epoch: 1, stateName: "commit", - stakerId: 1, staker: bindings.StructsStaker{Id: 1, Stake: big.NewInt(0)}, ethBalance: big.NewInt(1000), actualStake: big.NewFloat(0), @@ -1023,7 +1058,6 @@ func TestHandleBlock(t *testing.T) { state: 0, epoch: 1, stateName: "commit", - stakerId: 1, staker: bindings.StructsStaker{Id: 1, Stake: big.NewInt(10000), IsSlashed: true}, ethBalance: big.NewInt(1000), actualStake: big.NewFloat(10000), @@ -1038,7 +1072,6 @@ func TestHandleBlock(t *testing.T) { state: 0, epoch: 1, stateName: "commit", - stakerId: 1, staker: bindings.StructsStaker{Id: 1, Stake: big.NewInt(10000)}, ethBalance: big.NewInt(1000), actualStake: big.NewFloat(10000), @@ -1054,7 +1087,6 @@ func TestHandleBlock(t *testing.T) { state: 1, epoch: 1, stateName: "reveal", - stakerId: 1, staker: bindings.StructsStaker{Id: 1, Stake: big.NewInt(10000)}, ethBalance: big.NewInt(1000), actualStake: big.NewFloat(10000), @@ -1070,7 +1102,6 @@ func TestHandleBlock(t *testing.T) { state: 2, epoch: 1, stateName: "propose", - stakerId: 1, staker: bindings.StructsStaker{Id: 1, Stake: big.NewInt(10000)}, ethBalance: big.NewInt(1000), actualStake: big.NewFloat(10000), @@ -1086,7 +1117,6 @@ func TestHandleBlock(t *testing.T) { state: 3, epoch: 1, stateName: "dispute", - stakerId: 1, staker: bindings.StructsStaker{Id: 1, Stake: big.NewInt(10000)}, ethBalance: big.NewInt(1000), actualStake: big.NewFloat(10000), @@ -1102,7 +1132,6 @@ func TestHandleBlock(t *testing.T) { state: 3, epoch: 1, stateName: "dispute", - stakerId: 1, staker: bindings.StructsStaker{Id: 1, Stake: big.NewInt(10000)}, ethBalance: big.NewInt(1000), actualStake: big.NewFloat(10000), @@ -1118,7 +1147,6 @@ func TestHandleBlock(t *testing.T) { state: 3, epoch: 1, stateName: "dispute", - stakerId: 1, staker: bindings.StructsStaker{Id: 1, Stake: big.NewInt(10000)}, ethBalance: big.NewInt(1000), actualStake: big.NewFloat(10000), @@ -1136,7 +1164,6 @@ func TestHandleBlock(t *testing.T) { epoch: 1, stateName: "confirm", lastVerification: 1, - stakerId: 1, staker: bindings.StructsStaker{Id: 1, Stake: big.NewInt(10000)}, ethBalance: big.NewInt(1000), actualStake: big.NewFloat(10000), @@ -1153,7 +1180,6 @@ func TestHandleBlock(t *testing.T) { epoch: 2, stateName: "confirm", lastVerification: 1, - stakerId: 2, staker: bindings.StructsStaker{Id: 2, Stake: big.NewInt(10000)}, ethBalance: big.NewInt(1000), actualStake: big.NewFloat(10000), @@ -1170,7 +1196,6 @@ func TestHandleBlock(t *testing.T) { epoch: 1, lastVerification: 4, stateName: "dispute", - stakerId: 1, staker: bindings.StructsStaker{Id: 1, Stake: big.NewInt(10000)}, ethBalance: big.NewInt(1000), actualStake: big.NewFloat(10000), @@ -1186,7 +1211,6 @@ func TestHandleBlock(t *testing.T) { epoch: 1, lastVerification: 4, stateName: "", - stakerId: 1, staker: bindings.StructsStaker{Id: 1, Stake: big.NewInt(10000)}, ethBalance: big.NewInt(1000), actualStake: big.NewFloat(10000), @@ -1201,16 +1225,15 @@ func TestHandleBlock(t *testing.T) { t.Run(tt.name, func(t *testing.T) { SetUpMockInterfaces() - utilsMock.On("GetBufferedState", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("int32")).Return(tt.args.state, tt.args.stateErr) + utilsMock.On("GetBufferedState", mock.Anything, mock.Anything, mock.Anything).Return(tt.args.state, tt.args.stateErr) utilsMock.On("GetEpoch", mock.AnythingOfType("*ethclient.Client")).Return(tt.args.epoch, tt.args.epochErr) - utilsMock.On("GetStakerId", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("string")).Return(tt.args.stakerId, tt.args.stakerIdErr) utilsMock.On("GetStaker", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("uint32")).Return(tt.args.staker, tt.args.stakerErr) clientUtilsMock.On("BalanceAtWithRetry", mock.AnythingOfType("*ethclient.Client"), mock.Anything).Return(tt.args.ethBalance, tt.args.ethBalanceErr) utilsMock.On("GetStakerSRZRBalance", mock.Anything, mock.Anything).Return(tt.args.sRZRBalance, tt.args.sRZRBalanceErr) osMock.On("Exit", mock.AnythingOfType("int")).Return() - cmdUtilsMock.On("InitiateCommit", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.initiateCommitErr) - cmdUtilsMock.On("InitiateReveal", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.initiateRevealErr) - cmdUtilsMock.On("InitiatePropose", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.initiateProposeErr) + cmdUtilsMock.On("InitiateCommit", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.initiateCommitErr) + cmdUtilsMock.On("InitiateReveal", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.initiateRevealErr) + cmdUtilsMock.On("InitiatePropose", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.initiateProposeErr) cmdUtilsMock.On("HandleDispute", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.handleDisputeErr) utilsMock.On("IsFlagPassed", mock.AnythingOfType("string")).Return(tt.args.isFlagPassed) cmdUtilsMock.On("HandleClaimBounty", mock.Anything, mock.Anything, mock.Anything).Return(tt.args.handleClaimBountyErr) @@ -1220,7 +1243,67 @@ func TestHandleBlock(t *testing.T) { utilsMock.On("WaitTillNextNSecs", mock.AnythingOfType("int32")).Return() lastVerification = tt.args.lastVerification ut := &UtilsStruct{} - ut.HandleBlock(client, account, blockNumber, tt.args.config, rogueData, backupNodeActionsToIgnore) + ut.HandleBlock(client, account, stakerId, latestHeader, tt.args.config, commitParams, rogueData, backupNodeActionsToIgnore) + }) + } +} + +func TestVote(t *testing.T) { + var ( + config types.Configurations + client *ethclient.Client + rogueData types.Rogue + account types.Account + stakerId uint32 + commitParams *types.CommitParams + backupNodeActionsToIgnore []string + ) + type args struct { + header *Types.Header + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "Test when context is cancelled", + args: args{ + header: &Types.Header{ + Number: big.NewInt(101), + }, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + SetUpMockInterfaces() + + clientUtilsMock.On("GetLatestBlockWithRetry", mock.Anything).Return(tt.args.header, nil) + cmdUtilsMock.On("HandleBlock", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) + + ut := &UtilsStruct{} + errChan := make(chan error) + // Run Vote function in a goroutine + go func() { + errChan <- ut.Vote(ctx, config, client, account, stakerId, commitParams, rogueData, backupNodeActionsToIgnore) + }() + + // Wait for some time to allow Vote function to execute + time.Sleep(time.Second * 2) + + // Cancel the context to simulate its done + cancel() + + // Check the error returned from the function + err := <-errChan + if (err != nil) != tt.wantErr { + t.Errorf("Vote() error = %v, wantErr %v", err, tt.wantErr) + } }) } } diff --git a/config.sh b/config.sh index c5b6e40a4..fc353a4c1 100644 --- a/config.sh +++ b/config.sh @@ -25,9 +25,9 @@ then BUFFER=20 fi -read -rp "Wait Time: (1) " WAIT_TIME +read -rp "Wait Time: (5) " WAIT_TIME if [ -z "$WAIT_TIME" ]; then - WAIT_TIME=1 + WAIT_TIME=5 fi read -rp "Gas Price: (1) " GAS_PRICE @@ -45,14 +45,14 @@ if [ -z "$MAX_SIZE" ]; then MAX_SIZE=200 fi -read -rp "Log Files Max Backups: (52) " MAX_BACKUPS +read -rp "Log Files Max Backups: (10) " MAX_BACKUPS if [ -z "$MAX_BACKUPS" ]; then - MAX_BACKUPS=52 + MAX_BACKUPS=10 fi -read -rp "Log Files Max Age: (365) " MAX_AGE +read -rp "Log Files Max Age: (60) " MAX_AGE if [ -z "$MAX_AGE" ]; then - MAX_AGE=365 + MAX_AGE=60 fi ALT_PROVIDER_OPTION="" diff --git a/core/constants.go b/core/constants.go index fb7d069aa..d895a7227 100644 --- a/core/constants.go +++ b/core/constants.go @@ -3,72 +3,111 @@ package core import ( - "github.com/ethereum/go-ethereum/common" "math/big" + + "github.com/ethereum/go-ethereum/common" ) -var EpochLength uint64 = 1200 -var NumberOfStates uint64 = 5 -var StateLength = EpochLength / NumberOfStates +const ( + EpochLength uint64 = 1200 + NumberOfStates uint64 = 5 + StateLength = EpochLength / NumberOfStates +) // ChainId corresponds to the SKALE chain -var ChainId = big.NewInt(0x5a79c44e) +var ChainId = big.NewInt(0x109B4597) + +const MaxRetries uint = 8 -var MaxRetries uint = 8 var NilHash = common.Hash{0x00} -var BlockCompletionTimeout = 30 + +const BlockCompletionTimeout = 30 //Following are the default config values for all the config parameters +const ( + DefaultGasMultiplier float32 = 1.0 + DefaultBufferPercent int32 = 20 + DefaultGasPrice int32 = 1 + DefaultWaitTime int32 = 5 + DefaultGasLimit float32 = 2 + DefaultGasLimitOverride uint64 = 30000000 + DefaultRPCTimeout int64 = 10 + DefaultHTTPTimeout int64 = 10 + DefaultLogLevel = "" +) -var DefaultGasMultiplier = 1.0 -var DefaultBufferPercent = 20 -var DefaultGasPrice = 1 -var DefaultWaitTime = 1 -var DefaultGasLimit = 2 -var DefaultGasLimitOverride = 50000000 -var DefaultRPCTimeout = 10 -var DefaultHTTPTimeout = 10 -var DefaultLogLevel = "" +//BufferStateSleepTime is the sleeping time whenever buffer state hits +const BufferStateSleepTime int32 = 5 //Following are the default logFile parameters in config - -var DefaultLogFileMaxSize = 200 -var DefaultLogFileMaxBackups = 52 -var DefaultLogFileMaxAge = 365 +const ( + DefaultLogFileMaxSize = 200 + DefaultLogFileMaxBackups = 10 + DefaultLogFileMaxAge = 60 +) //DisputeGasMultiplier is a constant gasLimitMultiplier to increase gas Limit for function `disputeCollectionIdShouldBeAbsent` and `disputeCollectionIdShouldBePresent` -var DisputeGasMultiplier float32 = 5.5 +const DisputeGasMultiplier float32 = 5.5 // Following are the constants which will be used to derive different file paths -var DataFileDirectory = "data_files" -var CommitDataFile = "_commitData.json" -var ProposeDataFile = "_proposeData.json" -var DisputeDataFile = "_disputeData.json" -var AssetsDataFile = "assets.json" -var ConfigFile = "razor.yaml" -var LogFileDirectory = "logs" -var DefaultPathName = ".razor" +const ( + DataFileDirectory = "data_files" + CommitDataFile = "_commitData.json" + ProposeDataFile = "_proposeData.json" + DisputeDataFile = "_disputeData.json" + AssetsDataFile = "assets.json" + ConfigFile = "razor.yaml" + LogFileDirectory = "logs" + DefaultPathName = ".razor" +) //BlockNumberInterval is the interval in seconds after which blockNumber needs to be calculated again -var BlockNumberInterval = 5 +const BlockNumberInterval = 5 //APIKeyRegex will be used as a regular expression to be matched in job Urls -var APIKeyRegex = `\$\{(.+?)\}` +const APIKeyRegex = `\$\{(.+?)\}` // Following are the constants which defines retry attempts and retry delay if there is an error in processing request - -var ProcessRequestRetryAttempts uint = 2 -var ProcessRequestRetryDelay = 2 +const ( + ProcessRequestRetryAttempts uint = 2 + ProcessRequestRetryDelay int64 = 2 +) //SwitchClientDuration is the time after which alternate client from secondary RPC will be switched back to client from primary RPC -var SwitchClientDuration = 5 * EpochLength +const SwitchClientDuration = 5 * EpochLength + +const ( + // HexReturnType is the ReturnType for a job if that job returns a hex value + HexReturnType = "hex" + + // HexArrayReturnType is the ReturnType for a job if that job returns a hex array value + HexArrayReturnType = "^hexArray\\[\\d+\\]$" + + // HexArrayExtractIndexRegex will be used as a regular expression to extract index from hexArray return type + HexArrayExtractIndexRegex = `^hexArray\[(\d+)\]$` +) -// HexReturnType is the ReturnType for a job if that job returns a hex value -var HexReturnType = "hex" +// Following are the constants which helps in calculating iteration for a staker +const ( + BatchSize = 1000 + NumRoutines = 10 + MaxIterations = 10000000 +) -// HexArrayReturnType is the ReturnType for a job if that job returns a hex array value -var HexArrayReturnType = "^hexArray\\[\\d+\\]$" +// Following are the constants used in custom http.Transport configuration for the common HTTP client that we use for all the requests +const ( + HTTPClientMaxIdleConns = 15 + HTTPClientMaxIdleConnsPerHost = 5 +) -// HexArrayExtractIndexRegex will be used as a regular expression to extract index from hexArray return type -var HexArrayExtractIndexRegex = `^hexArray\[(\d+)\]$` +const GetStakeSnapshotMethod = "getStakeSnapshot" + +// Following are the event names that nodes will listen to in order to update the jobs/collections in the cache +const ( + JobCreatedEvent = "JobCreated" + CollectionCreatedEvent = "CollectionCreated" + JobUpdatedEvent = "JobUpdated" + CollectionUpdatedEvent = "CollectionUpdated" + CollectionActivityStatusEvent = "CollectionActivityStatus" +) diff --git a/core/contracts.go b/core/contracts.go index f6650e983..6079b6e4c 100644 --- a/core/contracts.go +++ b/core/contracts.go @@ -1,7 +1,7 @@ package core -var StakeManagerAddress = "0x9f55a2C6C1F1Be8B01562cEae2df2F22931C7a46" -var RAZORAddress = "0x4500E10fEb89e46E9fb642D0c62b1a761278155D" -var CollectionManagerAddress = "0x3b76eB8c0282dAf531D7C507E4f3143A9A9c38b1" -var VoteManagerAddress = "0x11995b74D6d07a6Edc05653a71F3e8B3354caBF0" -var BlockManagerAddress = "0x096e44B0d8b68376C8Efe40F28C3857951f03069" +var StakeManagerAddress = "0xe0bC695203d9C9f379bcdE9260B9F71B64B85298" +var RAZORAddress = "0xcbf70914Fae03B3acB91E953De60CfDAaCA8145f" +var CollectionManagerAddress = "0x367962d1462C568A0dDd0e2448311469451bF5a3" +var VoteManagerAddress = "0x641BAD0641eB5B94B19568C0a22a55AEbDAF1870" +var BlockManagerAddress = "0x11aB70d78f1Dd2c3F967180d8A64858Db03A0aBa" diff --git a/core/types/account.go b/core/types/account.go index 5d230f53d..da16d30d8 100644 --- a/core/types/account.go +++ b/core/types/account.go @@ -1,6 +1,21 @@ package types +import ( + "crypto/ecdsa" + "github.com/ethereum/go-ethereum/accounts" +) + +//go:generate mockery --name=AccountManagerInterface --output=../../accounts/mocks --case=underscore + type Account struct { - Address string - Password string + Address string + Password string + AccountManager AccountManagerInterface +} + +type AccountManagerInterface interface { + CreateAccount(keystorePath, password string) accounts.Account + GetPrivateKey(address, password string) (*ecdsa.PrivateKey, error) + SignData(hash []byte, address string, password string) ([]byte, error) + NewAccount(passphrase string) (accounts.Account, error) } diff --git a/core/types/assets.go b/core/types/assets.go index 6f8ee01f2..ac5a7166f 100644 --- a/core/types/assets.go +++ b/core/types/assets.go @@ -64,3 +64,8 @@ type DataSourceURL struct { Header map[string]string `json:"header"` ReturnType string `json:"returnType"` } + +type CollectionResult struct { + Index int + Leaf *big.Int +} diff --git a/core/types/configurations.go b/core/types/configurations.go index a9864fb56..da42176f9 100644 --- a/core/types/configurations.go +++ b/core/types/configurations.go @@ -16,3 +16,9 @@ type Configurations struct { LogFileMaxBackups int LogFileMaxAge int } + +type ConfigDetail struct { + FlagName string + Key string + DefaultValue interface{} +} diff --git a/core/types/flag.go b/core/types/flag.go new file mode 100644 index 000000000..28cfa9be4 --- /dev/null +++ b/core/types/flag.go @@ -0,0 +1,11 @@ +package types + +type FlagDetail struct { + Name string + Type string +} + +type FlagValue struct { + Value interface{} + DefaultNilValue interface{} +} diff --git a/core/types/inputs.go b/core/types/inputs.go index 63893e579..89a95d18d 100644 --- a/core/types/inputs.go +++ b/core/types/inputs.go @@ -3,29 +3,25 @@ package types import "math/big" type UnstakeInput struct { - Address string - Password string + Account Account ValueInWei *big.Int StakerId uint32 } type RedeemBountyInput struct { - Address string - Password string + Account Account BountyId uint32 } type TransferInput struct { - FromAddress string - ToAddress string - Password string - ValueInWei *big.Int - Balance *big.Int + Account Account + ToAddress string + ValueInWei *big.Int + Balance *big.Int } type CreateJobInput struct { - Address string - Password string + Account Account Name string Url string Selector string @@ -35,9 +31,8 @@ type CreateJobInput struct { } type CreateCollectionInput struct { - Address string + Account Account Name string - Password string Aggregation uint32 Power int8 JobIds []uint @@ -45,21 +40,18 @@ type CreateCollectionInput struct { } type ExtendLockInput struct { - Address string - Password string + Account Account StakerId uint32 } type ModifyCollectionInput struct { - Address string - Password string + Account Account CollectionId uint16 Status bool } type SetDelegationInput struct { - Address string - Password string + Account Account Status bool StatusString string StakerId uint32 @@ -67,8 +59,7 @@ type SetDelegationInput struct { } type UpdateCommissionInput struct { - Address string - Password string + Account Account Commission uint8 StakerId uint32 } diff --git a/core/types/transaction.go b/core/types/transaction.go index 11c46cadd..75819cacf 100644 --- a/core/types/transaction.go +++ b/core/types/transaction.go @@ -7,14 +7,13 @@ import ( type TransactionOptions struct { Client *ethclient.Client - Password string EtherValue *big.Int Amount *big.Int - AccountAddress string ChainId *big.Int Config Configurations ContractAddress string MethodName string Parameters []interface{} ABI string + Account Account } diff --git a/core/types/vote.go b/core/types/vote.go index 445afb5ae..f1fb7b743 100644 --- a/core/types/vote.go +++ b/core/types/vote.go @@ -1,6 +1,10 @@ package types -import "math/big" +import ( + "math/big" + "net/http" + "razor/cache" +) type ElectedProposer struct { Iteration int @@ -62,3 +66,11 @@ type ProposeFileData struct { RevealedCollectionIds []uint16 RevealedDataMaps *RevealedDataMaps } + +type CommitParams struct { + JobsCache *cache.JobsCache + CollectionsCache *cache.CollectionsCache + LocalCache *cache.LocalCache + HttpClient *http.Client + FromBlockToCheckForEvents *big.Int +} diff --git a/core/version.go b/core/version.go index 834f4869d..b31b42ff3 100644 --- a/core/version.go +++ b/core/version.go @@ -4,7 +4,7 @@ import "fmt" const ( VersionMajor = 1 // Major version component of the current release - VersionMinor = 1 // Minor version component of the current release + VersionMinor = 2 // Minor version component of the current release VersionPatch = 0 // Patch version component of the current release VersionMeta = "" // Version metadata to append to the version string ) diff --git a/package.json b/package.json index a14ebf922..2fd2ffcc2 100644 --- a/package.json +++ b/package.json @@ -7,6 +7,7 @@ "setup": "make setup", "build": "make build", "build-all": "make all", + "build-all-testnet": "make all-testnet", "build-noargs": "make build-noargs", "build-noargs-testnet": "make build-noargs-testnet", "test": "go test ./... -v" diff --git a/update-chainId.sh b/update-chainId.sh index 215053115..591a317d4 100644 --- a/update-chainId.sh +++ b/update-chainId.sh @@ -7,7 +7,7 @@ CHAINID="" if [[ "$NETWORK" == "mainnet" ]]; then CHAINID="0x109B4597" elif [[ "$NETWORK" == "testnet" ]]; then - CHAINID="0x5A79C44E" + CHAINID="0x561bf78b" else echo "Invalid network specified. Please choose 'mainnet' or 'testnet'." exit 1 diff --git a/utils/api.go b/utils/api.go index 0b5dc2c6a..a64eb1a41 100644 --- a/utils/api.go +++ b/utils/api.go @@ -7,7 +7,6 @@ import ( "fmt" "io" "net/http" - "razor/cache" "razor/core" "regexp" "time" @@ -19,34 +18,30 @@ import ( "github.com/gocolly/colly" ) -func GetDataFromAPI(dataSourceURLStruct types.DataSourceURL, localCache *cache.LocalCache) ([]byte, error) { - client := http.Client{ - Timeout: time.Duration(HTTPTimeout) * time.Second, - } - +func GetDataFromAPI(commitParams *types.CommitParams, dataSourceURLStruct types.DataSourceURL) ([]byte, error) { cacheKey, err := generateCacheKey(dataSourceURLStruct.URL, dataSourceURLStruct.Body) if err != nil { log.Errorf("Error in generating cache key for API %s: %v", dataSourceURLStruct.URL, err) return nil, err } - cachedData, found := localCache.Read(cacheKey) + cachedData, found := commitParams.LocalCache.Read(cacheKey) if found { log.Debugf("Getting Data for URL %s from local cache...", dataSourceURLStruct.URL) return cachedData, nil } - response, err := makeAPIRequest(client, dataSourceURLStruct) + response, err := makeAPIRequest(commitParams.HttpClient, dataSourceURLStruct) if err != nil { return nil, err } // Storing the data into cache - localCache.Update(response, cacheKey, time.Now().Add(time.Second*time.Duration(core.StateLength)).Unix()) + commitParams.LocalCache.Update(response, cacheKey, time.Now().Add(time.Second*time.Duration(core.StateLength)).Unix()) return response, nil } -func makeAPIRequest(client http.Client, dataSourceURLStruct types.DataSourceURL) ([]byte, error) { +func makeAPIRequest(httpClient *http.Client, dataSourceURLStruct types.DataSourceURL) ([]byte, error) { var requestBody io.Reader // Using the broader io.Reader interface here switch dataSourceURLStruct.Type { @@ -68,7 +63,7 @@ func makeAPIRequest(client http.Client, dataSourceURLStruct types.DataSourceURL) var response []byte err := retry.Do( func() error { - responseBody, err := ProcessRequest(client, dataSourceURLStruct, requestBody) + responseBody, err := ProcessRequest(httpClient, dataSourceURLStruct, requestBody) if err != nil { log.Errorf("Error in processing %s request: %v", dataSourceURLStruct.Type, err) return err @@ -84,6 +79,28 @@ func makeAPIRequest(client http.Client, dataSourceURLStruct types.DataSourceURL) return response, nil } +func parseJSONData(parsedJSON interface{}, selector string) (interface{}, error) { + switch v := parsedJSON.(type) { + case map[string]interface{}: // Handling JSON object response case + return GetDataFromJSON(v, selector) + + case []interface{}: // Handling JSON array of objects response case + if len(v) > 0 { + // The first element from JSON array is fetched + if elem, ok := v[0].(map[string]interface{}); ok { + return GetDataFromJSON(elem, selector) + } + log.Error("Element in array is not a JSON object") + return nil, errors.New("element in array is not a JSON object") + } + log.Error("Empty JSON array") + return nil, errors.New("empty JSON array") + default: + log.Error("Unexpected JSON structure") + return nil, errors.New("unexpected JSON structure") + } +} + func GetDataFromJSON(jsonObject map[string]interface{}, selector string) (interface{}, error) { if selector[0] == '[' { selector = "$" + selector @@ -124,13 +141,13 @@ func addHeaderToRequest(request *http.Request, headerMap map[string]string) *htt return request } -func ProcessRequest(client http.Client, dataSourceURLStruct types.DataSourceURL, requestBody io.Reader) ([]byte, error) { +func ProcessRequest(httpClient *http.Client, dataSourceURLStruct types.DataSourceURL, requestBody io.Reader) ([]byte, error) { request, err := http.NewRequest(dataSourceURLStruct.Type, dataSourceURLStruct.URL, requestBody) if err != nil { return nil, err } requestWithHeader := addHeaderToRequest(request, dataSourceURLStruct.Header) - response, err := client.Do(requestWithHeader) + response, err := httpClient.Do(requestWithHeader) if err != nil { log.Errorf("Error sending %s request URL %s: %v", dataSourceURLStruct.Type, dataSourceURLStruct.URL, err) return nil, err diff --git a/utils/api_test.go b/utils/api_test.go index d64474806..0d5355237 100644 --- a/utils/api_test.go +++ b/utils/api_test.go @@ -2,11 +2,15 @@ package utils import ( "encoding/hex" + "encoding/json" + "net/http" "razor/cache" "razor/core/types" "reflect" "testing" "time" + + "github.com/stretchr/testify/assert" ) func getAPIByteArray(index int) []byte { @@ -29,9 +33,17 @@ func getAPIByteArray(index int) []byte { } func TestGetDataFromAPI(t *testing.T) { - //postRequestInput := `{"type": "POST","url": "https://staging-v3.skalenodes.com/v1/staging-aware-chief-gianfar","body": {"jsonrpc": "2.0","method": "eth_chainId","params": [],"id": 0},"header": {"content-type": "application/json"}}` + //postRequestInput := `{"type": "POST","url": "https://rpc.ankr.com/polygon_mumbai","body": {"jsonrpc": "2.0","method": "eth_chainId","params": [],"id": 0},"header": {"content-type": "application/json"}}` sampleChainId, _ := hex.DecodeString("7b226a736f6e727063223a22322e30222c22726573756c74223a223078616133376463222c226964223a307d0a") + httpClient := &http.Client{ + Timeout: 10 * time.Second, + Transport: &http.Transport{ + MaxIdleConns: 2, + MaxIdleConnsPerHost: 1, + }, + } + type args struct { urlStruct types.DataSourceURL } @@ -167,7 +179,11 @@ func TestGetDataFromAPI(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { localCache := cache.NewLocalCache(time.Second * 10) - got, err := GetDataFromAPI(tt.args.urlStruct, localCache) + commitParams := &types.CommitParams{ + LocalCache: localCache, + HttpClient: httpClient, + } + got, err := GetDataFromAPI(commitParams, tt.args.urlStruct) if (err != nil) != tt.wantErr { t.Errorf("GetDataFromAPI() error = %v, wantErr %v", err, tt.wantErr) return @@ -179,6 +195,63 @@ func TestGetDataFromAPI(t *testing.T) { } } +func TestParseJSONData(t *testing.T) { + tests := []struct { + name string + input string + selector string + expected interface{} + expectedErr string + }{ + { + name: "JSON Object", + input: `{"key1": "value1", "key2": "value2"}`, + selector: "key1", + expected: "value1", + }, + { + name: "Array of JSON Objects", + input: `[{"key1": "value1", "key2": "value2"}, {"key1": "value3", "key2": "value4"}]`, + selector: "key2", + expected: "value2", + }, + { + name: "Empty JSON Object", + input: `{}`, + selector: "key1", + expectedErr: "unknown key key1", + }, + { + name: "Empty JSON Array", + input: `[]`, + selector: "key1", + expectedErr: "empty JSON array", + }, + { + name: "Unexpected JSON Structure", + input: `"unexpected structure"`, + selector: "key1", + expectedErr: "unexpected JSON structure", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var parsedJSON interface{} + err := json.Unmarshal([]byte(tt.input), &parsedJSON) + assert.NoError(t, err) + + result, err := parseJSONData(parsedJSON, tt.selector) + if tt.expectedErr != "" { + assert.EqualError(t, err, tt.expectedErr) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + func TestGetDataFromJSON(t *testing.T) { type args struct { jsonObject map[string]interface{} diff --git a/utils/asset.go b/utils/asset.go index a08ed70fe..7f3602a41 100644 --- a/utils/asset.go +++ b/utils/asset.go @@ -14,6 +14,7 @@ import ( "regexp" "strconv" "strings" + "sync" "time" "github.com/avast/retry-go" @@ -140,21 +141,21 @@ func (*UtilsStruct) GetActiveCollectionIds(client *ethclient.Client) ([]uint16, return activeCollectionIds, nil } -func (*UtilsStruct) GetAggregatedDataOfCollection(client *ethclient.Client, collectionId uint16, epoch uint32, localCache *cache.LocalCache) (*big.Int, error) { - activeCollection, err := UtilsInterface.GetActiveCollection(client, collectionId) +func (*UtilsStruct) GetAggregatedDataOfCollection(client *ethclient.Client, collectionId uint16, epoch uint32, commitParams *types.CommitParams) (*big.Int, error) { + activeCollection, err := UtilsInterface.GetActiveCollection(commitParams.CollectionsCache, collectionId) if err != nil { log.Error(err) return nil, err } //Supply previous epoch to Aggregate in case if last reported value is required. - collectionData, aggregationError := UtilsInterface.Aggregate(client, epoch-1, activeCollection, localCache) + collectionData, aggregationError := UtilsInterface.Aggregate(client, epoch-1, activeCollection, commitParams) if aggregationError != nil { return nil, aggregationError } return collectionData, nil } -func (*UtilsStruct) Aggregate(client *ethclient.Client, previousEpoch uint32, collection bindings.StructsCollection, localCache *cache.LocalCache) (*big.Int, error) { +func (*UtilsStruct) Aggregate(client *ethclient.Client, previousEpoch uint32, collection bindings.StructsCollection, commitParams *types.CommitParams) (*big.Int, error) { var jobs []bindings.StructsJob var overriddenJobIds []uint16 @@ -164,7 +165,7 @@ func (*UtilsStruct) Aggregate(client *ethclient.Client, previousEpoch uint32, co return nil, err } if _, err := path.OSUtilsInterface.Stat(assetsFilePath); !errors.Is(err, os.ErrNotExist) { - log.Debug("Fetching the jobs from assets.json file...") + log.Debugf("assets.json file is present, checking jobs for collection Id: %v...", collection.Id) jsonFile, err := path.OSUtilsInterface.Open(assetsFilePath) if err != nil { return nil, err @@ -183,25 +184,24 @@ func (*UtilsStruct) Aggregate(client *ethclient.Client, previousEpoch uint32, co } // Overriding the jobs from contracts with official jobs present in asset.go - overrideJobs, overriddenJobIdsFromJSONfile := UtilsInterface.HandleOfficialJobsFromJSONFile(client, collection, dataString) + overrideJobs, overriddenJobIdsFromJSONfile := UtilsInterface.HandleOfficialJobsFromJSONFile(client, collection, dataString, commitParams) jobs = append(jobs, overrideJobs...) overriddenJobIds = append(overriddenJobIds, overriddenJobIdsFromJSONfile...) // Also adding custom jobs to jobs array customJobs := GetCustomJobsFromJSONFile(collection.Name, dataString) if len(customJobs) != 0 { - log.Debugf("Got Custom Jobs from asset.json file: %+v", customJobs) + log.Debugf("Got Custom Jobs from asset.json file for collectionId %v: %+v", collection.Id, customJobs) } jobs = append(jobs, customJobs...) } for _, id := range collection.JobIDs { - // Ignoring the Jobs which are already overriden and added to jobs array if !Contains(overriddenJobIds, id) { - job, err := UtilsInterface.GetActiveJob(client, id) - if err != nil { - log.Errorf("Error in fetching job %d: %v", id, err) + job, isPresent := commitParams.JobsCache.GetJob(id) + if !isPresent { + log.Errorf("Job with id %v is not present in cache", id) continue } jobs = append(jobs, job) @@ -211,8 +211,8 @@ func (*UtilsStruct) Aggregate(client *ethclient.Client, previousEpoch uint32, co if len(jobs) == 0 { return nil, errors.New("no jobs present in the collection") } - dataToCommit, weight, err := UtilsInterface.GetDataToCommitFromJobs(jobs, localCache) - if err != nil || len(dataToCommit) == 0 { + dataToCommit, weight := UtilsInterface.GetDataToCommitFromJobs(jobs, commitParams) + if len(dataToCommit) == 0 { prevCommitmentData, err := UtilsInterface.FetchPreviousValue(client, previousEpoch, collection.Id) if err != nil { return nil, err @@ -242,10 +242,10 @@ func (*UtilsStruct) GetActiveJob(client *ethclient.Client, jobId uint16) (bindin return job, nil } -func (*UtilsStruct) GetActiveCollection(client *ethclient.Client, collectionId uint16) (bindings.StructsCollection, error) { - collection, err := UtilsInterface.GetCollection(client, collectionId) - if err != nil { - return bindings.StructsCollection{}, err +func (*UtilsStruct) GetActiveCollection(collectionsCache *cache.CollectionsCache, collectionId uint16) (bindings.StructsCollection, error) { + collection, isPresent := collectionsCache.GetCollection(collectionId) + if !isPresent { + return bindings.StructsCollection{}, errors.New("collection not present in cache") } if !collection.Active { return bindings.StructsCollection{}, errors.New("collection inactive") @@ -253,43 +253,58 @@ func (*UtilsStruct) GetActiveCollection(client *ethclient.Client, collectionId u return collection, nil } -func (*UtilsStruct) GetDataToCommitFromJobs(jobs []bindings.StructsJob, localCache *cache.LocalCache) ([]*big.Int, []uint8, error) { +func (*UtilsStruct) GetDataToCommitFromJobs(jobs []bindings.StructsJob, commitParams *types.CommitParams) ([]*big.Int, []uint8) { var ( + wg sync.WaitGroup + mu sync.Mutex data []*big.Int weight []uint8 ) + for _, job := range jobs { - dataToAppend, err := UtilsInterface.GetDataToCommitFromJob(job, localCache) - if err != nil { - continue - } - log.Debugf("Job %s gives data %s", job.Url, dataToAppend) - data = append(data, dataToAppend) - weight = append(weight, job.Weight) + wg.Add(1) + go processJobConcurrently(&wg, &mu, &data, &weight, job, commitParams) } - return data, weight, nil + + wg.Wait() + + return data, weight } -func (*UtilsStruct) GetDataToCommitFromJob(job bindings.StructsJob, localCache *cache.LocalCache) (*big.Int, error) { - var parsedJSON map[string]interface{} +func processJobConcurrently(wg *sync.WaitGroup, mu *sync.Mutex, data *[]*big.Int, weight *[]uint8, job bindings.StructsJob, commitParams *types.CommitParams) { + defer wg.Done() + + dataToAppend, err := UtilsInterface.GetDataToCommitFromJob(job, commitParams) + if err != nil { + return + } + log.Debugf("Job ID: %d, Job %s gives data %s", job.Id, job.Url, dataToAppend) + + mu.Lock() + defer mu.Unlock() + *data = append(*data, dataToAppend) + *weight = append(*weight, job.Weight) +} + +func (*UtilsStruct) GetDataToCommitFromJob(job bindings.StructsJob, commitParams *types.CommitParams) (*big.Int, error) { var ( response []byte apiErr error dataSourceURLStruct types.DataSourceURL ) - log.Debugf("Getting the data to commit for job %s having job Id %d", job.Name, job.Id) + log.Debugf("Job ID: %d, Getting the data to commit for job %s", job.Id, job.Name) if isJSONCompatible(job.Url) { - log.Debug("Job URL passed is a struct containing URL along with type of request data") + log.Debugf("Job ID: %d, Job URL passed is a struct containing URL along with type of request data", job.Id) dataSourceURLInBytes := []byte(job.Url) err := json.Unmarshal(dataSourceURLInBytes, &dataSourceURLStruct) if err != nil { - log.Errorf("Error in unmarshalling %s: %v", job.Url, err) + log.Errorf("Job ID: %d, Error in unmarshalling %s: %v", job.Id, job.Url, err) return nil, err } - log.Infof("URL Struct: %+v", dataSourceURLStruct) + log.Infof("Job ID: %d, URL Struct: %+v", job.Id, dataSourceURLStruct) } else { - log.Debug("Job URL passed is a direct URL: ", job.Url) + log.Debugf("Job ID: %d, Job URL passed is a direct URL: %s", job.Id, job.Url) re := regexp.MustCompile(core.APIKeyRegex) isAPIKeyRequired := re.MatchString(job.Url) if isAPIKeyRequired { @@ -306,29 +321,30 @@ func (*UtilsStruct) GetDataToCommitFromJob(job bindings.StructsJob, localCache * var parsedData interface{} if job.SelectorType == 0 { start := time.Now() - response, apiErr = GetDataFromAPI(dataSourceURLStruct, localCache) + response, apiErr = GetDataFromAPI(commitParams, dataSourceURLStruct) if apiErr != nil { - log.Errorf("Error in fetching data from API %s: %v", job.Url, apiErr) + log.Errorf("Job ID: %d, Error in fetching data from API %s: %v", job.Id, job.Url, apiErr) return nil, apiErr } elapsed := time.Since(start).Seconds() - log.Debugf("Time taken to fetch the data from API : %s was %f", dataSourceURLStruct.URL, elapsed) + log.Debugf("Job ID: %d, Time taken to fetch the data from API : %s was %f", job.Id, dataSourceURLStruct.URL, elapsed) + var parsedJSON interface{} err := json.Unmarshal(response, &parsedJSON) if err != nil { - log.Error("Error in parsing data from API: ", err) + log.Errorf("Job ID: %d, Error in parsing data from API: %v", job.Id, err) return nil, err } - parsedData, err = GetDataFromJSON(parsedJSON, job.Selector) + parsedData, err = parseJSONData(parsedJSON, job.Selector) if err != nil { - log.Error("Error in fetching value from parsed data: ", err) + log.Errorf("Job ID: %d, Error in parsing JSON data: %v", job.Id, err) return nil, err } } else { //TODO: Add retry here. dataPoint, err := GetDataFromXHTML(dataSourceURLStruct, job.Selector) if err != nil { - log.Error("Error in fetching value from parsed XHTML: ", err) + log.Errorf("Job ID: %d, Error in fetching value from parsed XHTML: %v", job.Id, err) return nil, err } // remove "," and currency symbols @@ -337,7 +353,7 @@ func (*UtilsStruct) GetDataToCommitFromJob(job bindings.StructsJob, localCache * datum, err := ConvertToNumber(parsedData, dataSourceURLStruct.ReturnType) if err != nil { - log.Error("Result is not a number") + log.Errorf("Job ID: %d, Result is not a number", job.Id) return nil, err } @@ -470,7 +486,7 @@ func ConvertCustomJobToStructJob(customJob types.CustomJob) bindings.StructsJob } } -func (*UtilsStruct) HandleOfficialJobsFromJSONFile(client *ethclient.Client, collection bindings.StructsCollection, dataString string) ([]bindings.StructsJob, []uint16) { +func (*UtilsStruct) HandleOfficialJobsFromJSONFile(client *ethclient.Client, collection bindings.StructsCollection, dataString string, commitParams *types.CommitParams) ([]bindings.StructsJob, []uint16) { var overrideJobs []bindings.StructsJob var overriddenJobIds []uint16 @@ -483,8 +499,9 @@ func (*UtilsStruct) HandleOfficialJobsFromJSONFile(client *ethclient.Client, col if officialJobsJSONResult.Exists() { officialJobs := officialJobsJSONResult.String() if officialJobs != "" { - job, err := UtilsInterface.GetActiveJob(client, jobIds[i]) - if err != nil { + job, isPresent := commitParams.JobsCache.GetJob(jobIds[i]) + if !isPresent { + log.Errorf("Job with id %v is not present in cache", jobIds[i]) continue } log.Debugf("Overriding job %s having jobId %d from official job present in assets.json file...", job.Url, job.Id) @@ -515,6 +532,54 @@ func (*UtilsStruct) HandleOfficialJobsFromJSONFile(client *ethclient.Client, col return overrideJobs, overriddenJobIds } +// InitJobsCache initializes the jobs cache with data fetched from the blockchain +func InitJobsCache(client *ethclient.Client, jobsCache *cache.JobsCache) error { + jobsCache.Mu.Lock() + defer jobsCache.Mu.Unlock() + + // Flush the jobsCache before initialization + for k := range jobsCache.Jobs { + delete(jobsCache.Jobs, k) + } + + numJobs, err := AssetManagerInterface.GetNumJobs(client) + if err != nil { + return err + } + for i := 1; i <= int(numJobs); i++ { + job, err := UtilsInterface.GetActiveJob(client, uint16(i)) + if err != nil { + return err + } + jobsCache.Jobs[job.Id] = job + } + return nil +} + +// InitCollectionsCache initializes the collections cache with data fetched from the blockchain +func InitCollectionsCache(client *ethclient.Client, collectionsCache *cache.CollectionsCache) error { + collectionsCache.Mu.Lock() + defer collectionsCache.Mu.Unlock() + + // Flush the collectionsCache before initialization + for k := range collectionsCache.Collections { + delete(collectionsCache.Collections, k) + } + + numCollections, err := AssetManagerInterface.GetNumCollections(client) + if err != nil { + return err + } + for i := 1; i <= int(numCollections); i++ { + collection, err := AssetManagerInterface.GetCollection(client, uint16(i)) + if err != nil { + return err + } + collectionsCache.Collections[collection.Id] = collection + } + return nil +} + func ReplaceValueWithDataFromENVFile(re *regexp.Regexp, value string) string { // substrings denotes all the occurrences of substring which satisfies APIKeyRegex substrings := re.FindAllString(value, -1) diff --git a/utils/asset_test.go b/utils/asset_test.go index 88216dea9..e45bf8314 100644 --- a/utils/asset_test.go +++ b/utils/asset_test.go @@ -5,6 +5,7 @@ import ( "fmt" "io/fs" "math/big" + "net/http" "os" "razor/cache" "razor/core" @@ -25,11 +26,18 @@ import ( ) func TestAggregate(t *testing.T) { - var client *ethclient.Client - var previousEpoch uint32 - var fileInfo fs.FileInfo + var ( + client *ethclient.Client + previousEpoch uint32 + fileInfo fs.FileInfo + ) - job := bindings.StructsJob{Id: 1, SelectorType: 1, Weight: 100, + job1 := bindings.StructsJob{Id: 1, SelectorType: 1, Weight: 100, + Power: 2, Name: "ethusd_gemini", Selector: "last", + Url: "https://api.gemini.com/v1/pubticker/ethusd", + } + + job2 := bindings.StructsJob{Id: 2, SelectorType: 1, Weight: 100, Power: 2, Name: "ethusd_gemini", Selector: "last", Url: "https://api.gemini.com/v1/pubticker/ethusd", } @@ -39,14 +47,13 @@ func TestAggregate(t *testing.T) { Id: 4, Power: 2, AggregationMethod: 2, - JobIDs: []uint16{1, 2, 3}, + JobIDs: []uint16{1, 2}, Name: "ethCollectionMean", } type args struct { collection bindings.StructsCollection - activeJob bindings.StructsJob - activeJobErr error + jobCacheError bool dataToCommit []*big.Int dataToCommitErr error weight []uint8 @@ -72,21 +79,20 @@ func TestAggregate(t *testing.T) { name: "Test 1: When Aggregate() executes successfully", args: args{ collection: collection, - activeJob: job, - dataToCommit: []*big.Int{big.NewInt(2)}, - weight: []uint8{100}, + dataToCommit: []*big.Int{big.NewInt(3827200), big.NewInt(3828474)}, + weight: []uint8{1, 1}, prevCommitmentData: big.NewInt(1), assetFilePath: "", statErr: nil, }, - want: big.NewInt(2), + want: big.NewInt(3827837), wantErr: false, }, { - name: "Test 2: When there is an error in getting activeJob", + name: "Test 2: When the job is not present in cache", args: args{ collection: collection, - activeJobErr: errors.New("activeJob error"), + jobCacheError: true, dataToCommit: []*big.Int{big.NewInt(2)}, weight: []uint8{100}, prevCommitmentData: big.NewInt(1), @@ -100,7 +106,6 @@ func TestAggregate(t *testing.T) { name: "Test 3: When there is an error in getting dataToCommit", args: args{ collection: collection, - activeJob: job, dataToCommitErr: errors.New("dataToCommit error"), weight: []uint8{100}, prevCommitmentData: big.NewInt(1), @@ -112,7 +117,6 @@ func TestAggregate(t *testing.T) { name: "Test 4: When there is an error in getting prevCommitmentData", args: args{ collection: collection, - activeJob: job, dataToCommit: []*big.Int{big.NewInt(2)}, weight: []uint8{100}, prevCommitmentDataErr: errors.New("prevCommitmentData error"), @@ -124,7 +128,6 @@ func TestAggregate(t *testing.T) { name: "Test 5: When there is an error in getting prevCommitmentData", args: args{ collection: collection, - activeJob: job, dataToCommitErr: errors.New("dataToCommit error"), weight: []uint8{100}, prevCommitmentDataErr: errors.New("prevCommitmentData error"), @@ -172,6 +175,16 @@ func TestAggregate(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + commitParams := &types.CommitParams{ + JobsCache: cache.NewJobsCache(), + CollectionsCache: cache.NewCollectionsCache(), + } + if !tt.args.jobCacheError { + commitParams.JobsCache.Jobs[job1.Id] = job1 + commitParams.JobsCache.Jobs[job2.Id] = job2 + commitParams.CollectionsCache.Collections[collection.Id] = collection + } + utilsMock := new(mocks.Utils) pathUtilsMock := new(pathMocks.PathInterface) osUtilsMock := new(pathMocks.OSInterface) @@ -185,16 +198,16 @@ func TestAggregate(t *testing.T) { path.OSUtilsInterface = osUtilsMock utils := StartRazor(optionsPackageStruct) - utilsMock.On("GetActiveJob", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("uint16")).Return(tt.args.activeJob, tt.args.activeJobErr) - utilsMock.On("GetDataToCommitFromJobs", mock.Anything, mock.Anything).Return(tt.args.dataToCommit, tt.args.weight, tt.args.dataToCommitErr) + utilsMock.On("GetDataToCommitFromJobs", mock.Anything, mock.Anything, mock.Anything).Return(tt.args.dataToCommit, tt.args.weight, tt.args.dataToCommitErr) utilsMock.On("FetchPreviousValue", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("uint32"), mock.AnythingOfType("uint16")).Return(tt.args.prevCommitmentData, tt.args.prevCommitmentDataErr) pathUtilsMock.On("GetJobFilePath").Return(tt.args.assetFilePath, tt.args.assetFilePathErr) osUtilsMock.On("Stat", mock.Anything).Return(fileInfo, tt.args.statErr) osUtilsMock.On("Open", mock.Anything).Return(tt.args.jsonFile, tt.args.jsonFileErr) ioMock.On("ReadAll", mock.Anything).Return(tt.args.fileData, tt.args.fileDataErr) - utilsMock.On("HandleOfficialJobsFromJSONFile", mock.Anything, mock.Anything, mock.Anything).Return(tt.args.overrrideJobs, tt.args.overrideJobIds) + utilsMock.On("HandleOfficialJobsFromJSONFile", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.overrrideJobs, tt.args.overrideJobIds) + + got, err := utils.Aggregate(client, previousEpoch, tt.args.collection, commitParams) - got, err := utils.Aggregate(client, previousEpoch, tt.args.collection, &cache.LocalCache{}) if (err != nil) != tt.wantErr { t.Errorf("Aggregate() error = %v, wantErr %v", err, tt.wantErr) return @@ -267,24 +280,21 @@ func TestGetActiveCollectionIds(t *testing.T) { } func TestGetActiveCollection(t *testing.T) { - var client *ethclient.Client - var collectionId uint16 - - collectionEth := bindings.StructsCollection{Active: true, - Id: 2, - Power: 2, - AggregationMethod: 2, - JobIDs: []uint16{1, 2}, - Name: "ethCollectionMean", + collectionEth := bindings.StructsCollection{ + Active: true, Id: 1, Power: 2, + AggregationMethod: 2, JobIDs: []uint16{1, 2}, + Name: "ethCollectionMean", } - collectionEthInactive := bindings.StructsCollection{Active: false, Id: 2, Power: 2, - AggregationMethod: 2, JobIDs: []uint16{1, 2}, Name: "ethCollectionMean", + collectionEthInactive := bindings.StructsCollection{ + Active: false, Id: 2, Power: 2, + AggregationMethod: 2, JobIDs: []uint16{1, 2}, + Name: "ethCollectionMean", } type args struct { - collection bindings.StructsCollection - collectionErr error + collectionId uint16 + collectionCacheErr bool } tests := []struct { name string @@ -295,15 +305,15 @@ func TestGetActiveCollection(t *testing.T) { { name: "Test 1: When GetActiveCollection() executes successfully", args: args{ - collection: collectionEth, + collectionId: 1, }, want: collectionEth, wantErr: false, }, { - name: "Test 2: When there is an error in getting collection", + name: "Test 2: When the collection is not present in cache", args: args{ - collectionErr: errors.New("collection error"), + collectionCacheErr: true, }, want: bindings.StructsCollection{}, wantErr: true, @@ -311,7 +321,7 @@ func TestGetActiveCollection(t *testing.T) { { name: "Test 3: When there is an inactive collection", args: args{ - collection: collectionEthInactive, + collectionId: 2, }, want: bindings.StructsCollection{}, wantErr: true, @@ -319,16 +329,12 @@ func TestGetActiveCollection(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - utilsMock := new(mocks.Utils) - - optionsPackageStruct := OptionsPackageStruct{ - UtilsInterface: utilsMock, - } - utils := StartRazor(optionsPackageStruct) + collectionCache := cache.NewCollectionsCache() + collectionCache.Collections[collectionEth.Id] = collectionEth + collectionCache.Collections[collectionEthInactive.Id] = collectionEthInactive - utilsMock.On("GetCollection", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("uint16")).Return(tt.args.collection, tt.args.collectionErr) - - got, err := utils.GetActiveCollection(client, collectionId) + utils := UtilsStruct{} + got, err := utils.GetActiveCollection(collectionCache, tt.args.collectionId) if (err != nil) != tt.wantErr { t.Errorf("GetActiveCollection() error = %v, wantErr %v", err, tt.wantErr) return @@ -553,6 +559,14 @@ func TestGetAllCollections(t *testing.T) { } func TestGetDataToCommitFromJobs(t *testing.T) { + httpClient := &http.Client{ + Timeout: 10 * time.Second, + Transport: &http.Transport{ + MaxIdleConns: 2, + MaxIdleConnsPerHost: 1, + }, + } + jobsArray := []bindings.StructsJob{ {Id: 1, SelectorType: 0, Weight: 10, Power: 2, Name: "ethusd_gemini", Selector: "last", @@ -603,7 +617,6 @@ func TestGetDataToCommitFromJobs(t *testing.T) { name string args args wantArrayLength int - wantErr bool }{ { name: "Test 1: Getting values from set of jobs of length 4", @@ -630,14 +643,12 @@ func TestGetDataToCommitFromJobs(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { UtilsInterface = &UtilsStruct{} - lc := cache.NewLocalCache(time.Second * 20) - - gotDataArray, gotWeightArray, err := UtilsInterface.GetDataToCommitFromJobs(tt.args.jobs, lc) - if (err != nil) != tt.wantErr { - t.Errorf("GetDataToCommitFromJob() error = %v, wantErr %v", err, tt.wantErr) - return + commitParams := &types.CommitParams{ + LocalCache: cache.NewLocalCache(time.Second * 10), + HttpClient: httpClient, } + gotDataArray, gotWeightArray := UtilsInterface.GetDataToCommitFromJobs(tt.args.jobs, commitParams) if len(gotDataArray) != tt.wantArrayLength || len(gotWeightArray) != tt.wantArrayLength { t.Errorf("GetDataToCommitFromJobs() got = %v, want %v", gotDataArray, tt.wantArrayLength) } @@ -648,6 +659,14 @@ func TestGetDataToCommitFromJobs(t *testing.T) { } func TestGetDataToCommitFromJob(t *testing.T) { + httpClient := &http.Client{ + Timeout: 10 * time.Second, + Transport: &http.Transport{ + MaxIdleConns: 2, + MaxIdleConnsPerHost: 1, + }, + } + job := bindings.StructsJob{Id: 1, SelectorType: 0, Weight: 100, Power: 2, Name: "ethusd_kraken", Selector: "result.XETHZUSD.c[0]", Url: `{"type": "GET","url": "https://api.kraken.com/0/public/Ticker?pair=ETHUSD","body": {},"header": {}}`, @@ -668,6 +687,16 @@ func TestGetDataToCommitFromJob(t *testing.T) { Url: `{"type": "POST","url": "https://rpc.ankr.com/eth","body": {"jsonrpc":"2.0","id":7269270904970082,"method":"eth_call","params":[{"from":"0x0000000000000000000000000000000000000000","data":"0xd06ca61f0000000000000000000000000000000000000000000000000de0b6b3a76400000000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000000200000000000000000000000050de6856358cc35f3a9a57eaaa34bd4cb707d2cd0000000000000000000000008e870d67f660d95d5be530380d0ec0bd388289e1","to":"0x7a250d5630b4cf539739df2c5dacb4c659f2488d"},"latest"]},"header": {"content-type": "application/json"}, "returnType": "hexArray[1]"}`, } + arrayOfObjectsJob := bindings.StructsJob{Id: 1, SelectorType: 0, Weight: 100, + Power: 2, Name: "ethusd_bitfinex", Selector: "last_price", + Url: "https://api.bitfinex.com/v1/pubticker/ethusd", + } + + arrayOfArraysJob := bindings.StructsJob{Id: 1, SelectorType: 0, Weight: 100, + Power: 2, Name: "ethusd_bitfinex_v2", Selector: "last_price", + Url: "https://api-pub.bitfinex.com/v2/tickers?symbols=tXDCUSD", + } + invalidDataSourceStructJob := bindings.StructsJob{Id: 1, SelectorType: 0, Weight: 100, Power: 2, Name: "ethusd_sample", Selector: "result", Url: `{"type": true,"url1": {}}`, @@ -733,6 +762,20 @@ func TestGetDataToCommitFromJob(t *testing.T) { want: nil, wantErr: false, }, + { + name: "Test 7: When GetDataToCommitFromJob() executes successfully for job returning response of type array of objects", + args: args{ + job: arrayOfObjectsJob, + }, + wantErr: false, + }, + { + name: "Test 8: When GetDataToCommitFromJob() fails for job returning response of type arrays of arrays as element in array is not a json object", + args: args{ + job: arrayOfArraysJob, + }, + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -743,12 +786,12 @@ func TestGetDataToCommitFromJob(t *testing.T) { } utils := StartRazor(optionsPackageStruct) - pathUtilsMock := new(pathMocks.PathInterface) - path.PathUtilsInterface = pathUtilsMock + commitParams := &types.CommitParams{ + LocalCache: cache.NewLocalCache(time.Second * 10), + HttpClient: httpClient, + } - pathUtilsMock.On("GetDotENVFilePath", mock.Anything).Return("$HOME/.razor/.env", nil) - lc := cache.NewLocalCache(time.Second * 10) - data, err := utils.GetDataToCommitFromJob(tt.args.job, lc) + data, err := utils.GetDataToCommitFromJob(tt.args.job, commitParams) fmt.Println("JOB returns data: ", data) if (err != nil) != tt.wantErr { t.Errorf("GetDataToCommitFromJob() error = %v, wantErr %v", err, tt.wantErr) @@ -1079,14 +1122,23 @@ func TestHandleOfficialJobsFromJSONFile(t *testing.T) { ethCollection1 := bindings.StructsCollection{ Active: true, Id: 7, Power: 2, - AggregationMethod: 2, JobIDs: []uint16{1, 2, 3}, Name: "ethCollection", + AggregationMethod: 2, JobIDs: []uint16{1, 2}, Name: "ethCollection", + } + + job1 := bindings.StructsJob{Id: 1, SelectorType: 0, Weight: 0, + Power: 2, Name: "ethusd_kucoin", Selector: "last", + Url: "http://kucoin.com/eth", + } + + job2 := bindings.StructsJob{Id: 2, SelectorType: 0, Weight: 2, + Power: 3, Name: "ethusd_coinbase", Selector: "eth2", + Url: "http://api.coinbase.com/eth", } type args struct { - collection bindings.StructsCollection - dataString string - job bindings.StructsJob - jobErr error + collection bindings.StructsCollection + dataString string + addJobToCache bool } tests := []struct { name string @@ -1097,19 +1149,19 @@ func TestHandleOfficialJobsFromJSONFile(t *testing.T) { { name: "Test 1: When officialJobs for collection is present in assets.json", args: args{ - collection: ethCollection, - dataString: jsonDataString, - job: bindings.StructsJob{ - Id: 1, - }, + collection: ethCollection, + dataString: jsonDataString, + addJobToCache: true, }, want: []bindings.StructsJob{ { - Id: 1, - Url: "http://kucoin.com/eth1", - Selector: "eth1", - Power: 2, - Weight: 2, + Id: 1, + SelectorType: 0, + Name: "ethusd_kucoin", + Url: "http://kucoin.com/eth1", + Selector: "eth1", + Power: 2, + Weight: 2, }, }, wantOverrideJobIds: []uint16{1}, @@ -1119,9 +1171,6 @@ func TestHandleOfficialJobsFromJSONFile(t *testing.T) { args: args{ collection: ethCollection, dataString: "", - job: bindings.StructsJob{ - Id: 1, - }, }, want: nil, wantOverrideJobIds: nil, @@ -1129,9 +1178,9 @@ func TestHandleOfficialJobsFromJSONFile(t *testing.T) { { name: "Test 3: When there is an error from GetActiveJob()", args: args{ - collection: ethCollection, - dataString: jsonDataString, - jobErr: errors.New("job error"), + collection: ethCollection, + dataString: jsonDataString, + addJobToCache: false, }, want: nil, wantOverrideJobIds: nil, @@ -1139,26 +1188,22 @@ func TestHandleOfficialJobsFromJSONFile(t *testing.T) { { name: "Test 4: When multiple jobIds are needed to be overridden from official jobs", args: args{ - collection: ethCollection1, - dataString: jsonDataString, - job: bindings.StructsJob{ - Id: 1, - Url: "http://kraken.com/eth1", - Selector: "data.ETH", - Power: 3, - Weight: 1, - }, + collection: ethCollection1, + dataString: jsonDataString, + addJobToCache: true, }, want: []bindings.StructsJob{ { Id: 1, + Name: "ethusd_kucoin", Url: "http://kucoin.com/eth1", Selector: "eth1", Power: 2, Weight: 2, }, { - Id: 1, + Id: 2, + Name: "ethusd_coinbase", Url: "http://api.coinbase.com/eth2", Selector: "eth2", Power: 3, @@ -1170,15 +1215,17 @@ func TestHandleOfficialJobsFromJSONFile(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - utilsMock := new(mocks.Utils) - utilsMock.On("GetActiveJob", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("uint16")).Return(tt.args.job, tt.args.jobErr) - - optionsPackageStruct := OptionsPackageStruct{ - UtilsInterface: utilsMock, + commitParams := &types.CommitParams{ + JobsCache: cache.NewJobsCache(), } - utils := StartRazor(optionsPackageStruct) + if tt.args.addJobToCache { + commitParams.JobsCache.Jobs[job1.Id] = job1 + commitParams.JobsCache.Jobs[job2.Id] = job2 + } + + utils := &UtilsStruct{} - gotJobs, gotOverrideJobIds := utils.HandleOfficialJobsFromJSONFile(client, tt.args.collection, tt.args.dataString) + gotJobs, gotOverrideJobIds := utils.HandleOfficialJobsFromJSONFile(client, tt.args.collection, tt.args.dataString, commitParams) if !reflect.DeepEqual(gotJobs, tt.want) { t.Errorf("HandleOfficialJobsFromJSONFile() gotJobs = %v, want %v", gotJobs, tt.want) } @@ -1281,9 +1328,9 @@ func TestGetAggregatedDataOfCollection(t *testing.T) { utils := StartRazor(optionsPackageStruct) utilsMock.On("GetActiveCollection", mock.Anything, mock.Anything).Return(tt.args.activeCollection, tt.args.activeCollectionErr) - utilsMock.On("Aggregate", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.collectionData, tt.args.aggregationErr) + utilsMock.On("Aggregate", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.collectionData, tt.args.aggregationErr) - got, err := utils.GetAggregatedDataOfCollection(client, collectionId, epoch, &cache.LocalCache{}) + got, err := utils.GetAggregatedDataOfCollection(client, collectionId, epoch, &types.CommitParams{HttpClient: &http.Client{}}) if (err != nil) != tt.wantErr { t.Errorf("GetAggregatedDataOfCollection() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/utils/batch.go b/utils/batch.go new file mode 100644 index 000000000..a7a5bfdbc --- /dev/null +++ b/utils/batch.go @@ -0,0 +1,132 @@ +package utils + +import ( + "context" + "errors" + "fmt" + "github.com/avast/retry-go" + "github.com/ethereum/go-ethereum/accounts/abi" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/ethclient" + "github.com/ethereum/go-ethereum/rpc" + "razor/core" +) + +//Each batch call may require multiple arguments therefore defining args as [][]interface{} + +// BatchCall performs a batch call to the Ethereum client, using the provided contract ABI, address, method name, and arguments. +func (c ClientStruct) BatchCall(client *ethclient.Client, contractABI *abi.ABI, contractAddress, methodName string, args [][]interface{}) ([][]interface{}, error) { + calls, err := ClientInterface.CreateBatchCalls(contractABI, contractAddress, methodName, args) + if err != nil { + log.Errorf("Error in creating batch calls: %v", err) + return nil, err + } + + err = performBatchCallWithRetry(client, calls) + if err != nil { + log.Errorf("Error in performing batch call: %v", err) + return nil, err + } + + results, err := processBatchResults(contractABI, methodName, calls) + if err != nil { + log.Errorf("Error in processing batch call result: %v", err) + return nil, err + } + + return results, nil +} + +// CreateBatchCalls creates a slice of rpc.BatchElem, each representing an Ethereum call, using the provided ABI, contract address, method name, and arguments. +func (c ClientStruct) CreateBatchCalls(contractABI *abi.ABI, contractAddress, methodName string, args [][]interface{}) ([]rpc.BatchElem, error) { + var calls []rpc.BatchElem + + for _, arg := range args { + data, err := contractABI.Pack(methodName, arg...) + if err != nil { + log.Errorf("Failed to pack data for method %s: %v", methodName, err) + return nil, err + } + + calls = append(calls, rpc.BatchElem{ + Method: "eth_call", + Args: []interface{}{ + map[string]interface{}{ + "to": contractAddress, + "data": fmt.Sprintf("0x%x", data), + }, + "latest", + }, + Result: new(string), + }) + } + return calls, nil +} + +func (c ClientStruct) PerformBatchCall(client *ethclient.Client, calls []rpc.BatchElem) error { + err := client.Client().BatchCallContext(context.Background(), calls) + if err != nil { + return err + } + return nil +} + +// performBatchCallWithRetry performs the batch call to the Ethereum client with retry logic. +func performBatchCallWithRetry(client *ethclient.Client, calls []rpc.BatchElem) error { + err := retry.Do(func() error { + err := ClientInterface.PerformBatchCall(client, calls) + if err != nil { + log.Errorf("Error in performing batch call, retrying: %v", err) + return err + } + for _, call := range calls { + if call.Error != nil { + log.Errorf("Error in call result: %v", call.Error) + return call.Error + } + } + return nil + }, retry.Attempts(core.MaxRetries)) + + if err != nil { + log.Errorf("All attempts failed to perform batch call: %v", err) + return err + } + + return nil +} + +// processBatchResults processes the results of the batch call, unpacking the data using the provided ABI and method name. +func processBatchResults(contractABI *abi.ABI, methodName string, calls []rpc.BatchElem) ([][]interface{}, error) { + var results [][]interface{} + + for _, call := range calls { + if call.Error != nil { + log.Errorf("Error in call result: %v", call.Error) + return nil, call.Error + } + + result, ok := call.Result.(*string) + if !ok { + log.Error("Failed to type assert call result to *string") + return nil, errors.New("type asserting of batch call result error") + } + + if result == nil || *result == "" { + return nil, errors.New("empty batch call result") + } + + data := common.FromHex(*result) + if len(data) == 0 { + return nil, errors.New("empty hex data") + } + + unpackedData, err := contractABI.Unpack(methodName, data) + if err != nil { + return nil, errors.New("unpacking data error") + } + + results = append(results, unpackedData) + } + return results, nil +} diff --git a/utils/batch_test.go b/utils/batch_test.go new file mode 100644 index 000000000..00cdbc40b --- /dev/null +++ b/utils/batch_test.go @@ -0,0 +1,202 @@ +package utils + +import ( + "errors" + "github.com/ethereum/go-ethereum/accounts/abi" + "github.com/ethereum/go-ethereum/ethclient" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "math/big" + "razor/core" + "razor/pkg/bindings" + "razor/utils/mocks" + "strings" + "testing" +) + +func TestBatchCall(t *testing.T) { + //Testing Batch call scenario for getting StakeSnapshot + var client *ethclient.Client + + voteManagerABI, _ := abi.JSON(strings.NewReader(bindings.VoteManagerMetaData.ABI)) + stakeManagerABI, _ := abi.JSON(strings.NewReader(bindings.StakeManagerMetaData.ABI)) + numberOfArguments := 3 + + type args struct { + contractABI *abi.ABI + contractAddress string + methodName string + createBatchCallsErr error + performBatchCallErr error + results []interface{} + callErrors []error + } + tests := []struct { + name string + args args + want [][]interface{} + wantErr bool + }{ + { + name: "Test 1: When batch call executes successfully", + args: args{ + contractABI: &voteManagerABI, + contractAddress: core.VoteManagerAddress, + methodName: core.GetStakeSnapshotMethod, + results: []interface{}{ + ptrString("0x000000000000000000000000000000000000000000000000000000000000000a"), + ptrString("0x000000000000000000000000000000000000000000000000000000000000000b"), + ptrString("0x000000000000000000000000000000000000000000000000000000000000000c"), + }, + callErrors: []error{nil, nil, nil}, + }, + want: [][]interface{}{ + {big.NewInt(10)}, + {big.NewInt(11)}, + {big.NewInt(12)}, + }, + wantErr: false, + }, + { + name: "Test 2: When one of batch calls throw an error", + args: args{ + contractABI: &voteManagerABI, + contractAddress: core.VoteManagerAddress, + methodName: core.GetStakeSnapshotMethod, + results: []interface{}{ + nil, + ptrString("0x000000000000000000000000000000000000000000000000000000000000000b"), + ptrString("0x000000000000000000000000000000000000000000000000000000000000000c"), + }, + callErrors: []error{errors.New("batch call error"), nil, nil}, + }, + want: nil, + wantErr: true, + }, + { + name: "Test 3: When BatchCalls receives an result of invalid type which cannot be type asserted to *string", + args: args{ + contractABI: &voteManagerABI, + contractAddress: core.VoteManagerAddress, + methodName: core.GetStakeSnapshotMethod, + results: []interface{}{ + 42, // intentionally incorrect data type, + ptrString("0x000000000000000000000000000000000000000000000000000000000000000b"), + ptrString("0x000000000000000000000000000000000000000000000000000000000000000c"), + }, + callErrors: []error{nil, nil, nil}, + }, + want: nil, + wantErr: true, + }, + { + name: "Test 4: When BatchCalls receives a nil result (empty batch call result error)", + args: args{ + contractABI: &voteManagerABI, + contractAddress: core.VoteManagerAddress, + methodName: core.GetStakeSnapshotMethod, + results: []interface{}{ + nil, + nil, + ptrString("0x000000000000000000000000000000000000000000000000000000000000000b"), + }, + callErrors: []error{nil, nil, nil}, + }, + want: nil, + wantErr: true, + }, + { + name: "Test 5: When BatchCalls receives an empty result (empty hex data error)", + args: args{ + contractABI: &voteManagerABI, + contractAddress: core.VoteManagerAddress, + methodName: core.GetStakeSnapshotMethod, + results: []interface{}{ + ptrString("0x"), + ptrString("0x000000000000000000000000000000000000000000000000000000000000000b"), + ptrString("0x000000000000000000000000000000000000000000000000000000000000000c"), + }, + callErrors: []error{nil, nil, nil}, + }, + want: nil, + wantErr: true, + }, + { + name: "Test 6: When incorrect ABI is provided for unpacking", + args: args{ + contractABI: &stakeManagerABI, + contractAddress: core.VoteManagerAddress, + methodName: core.GetStakeSnapshotMethod, + results: []interface{}{ + ptrString("0x000000000000000000000000000000000000000000000000000000000000000a"), + ptrString("0x000000000000000000000000000000000000000000000000000000000000000b"), + ptrString("0x000000000000000000000000000000000000000000000000000000000000000c"), + }, + callErrors: []error{nil, nil, nil}, + }, + want: nil, + wantErr: true, + }, + { + name: "Test 7: When there is an error in creating batch calls", + args: args{ + contractABI: &voteManagerABI, + contractAddress: core.VoteManagerAddress, + methodName: core.GetStakeSnapshotMethod, + createBatchCallsErr: errors.New("create batch calls error"), + }, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var stakerIds []uint32 + for i := 1; i <= numberOfArguments; i++ { + stakerIds = append(stakerIds, uint32(i)) + } + + arguments := make([][]interface{}, len(stakerIds)) + for i, stakerId := range stakerIds { + arguments[i] = []interface{}{uint32(100), stakerId} + } + + ClientInterface = &ClientStruct{} + calls, err := ClientInterface.CreateBatchCalls(tt.args.contractABI, tt.args.contractAddress, tt.args.methodName, arguments) + if err != nil { + log.Error("Error in creating batch calls: ", err) + return + } + // Mock batch call responses + for i, result := range tt.args.results { + if result != nil { + calls[i].Result = result + } + calls[i].Error = tt.args.callErrors[i] + } + clientMock := new(mocks.ClientUtils) + optionsPackageStruct := OptionsPackageStruct{ + ClientInterface: clientMock, + } + + StartRazor(optionsPackageStruct) + + clientMock.On("CreateBatchCalls", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(calls, tt.args.createBatchCallsErr) + clientMock.On("PerformBatchCall", mock.Anything, mock.Anything).Return(tt.args.performBatchCallErr) + + c := ClientStruct{} + gotResults, err := c.BatchCall(client, tt.args.contractABI, tt.args.contractAddress, tt.args.methodName, arguments) + if (err != nil) != tt.wantErr { + t.Errorf("BatchCall() error = %v, but wantErr bool is %v", err, tt.wantErr) + return + } + if !assert.Equal(t, gotResults, tt.want) { + t.Errorf("BatchCall() got = %v, want %v", gotResults, tt.want) + } + }) + } +} + +func ptrString(s string) *string { + return &s +} diff --git a/utils/common.go b/utils/common.go index 164c3a74a..3167e77ce 100644 --- a/utils/common.go +++ b/utils/common.go @@ -3,9 +3,11 @@ package utils import ( "context" "errors" + Types "github.com/ethereum/go-ethereum/core/types" "math/big" "os" "path/filepath" + "razor/accounts" "razor/core" "razor/core/types" "razor/logger" @@ -51,21 +53,17 @@ func (*UtilsStruct) FetchBalance(client *ethclient.Client, accountAddress string return balance, nil } -func (*UtilsStruct) GetBufferedState(client *ethclient.Client, buffer int32) (int64, error) { - block, err := ClientInterface.GetLatestBlockWithRetry(client) - if err != nil { - return -1, err - } +func (*UtilsStruct) GetBufferedState(client *ethclient.Client, header *Types.Header, buffer int32) (int64, error) { stateBuffer, err := UtilsInterface.GetStateBuffer(client) if err != nil { return -1, err } lowerLimit := (core.StateLength * uint64(buffer)) / 100 upperLimit := core.StateLength - (core.StateLength*uint64(buffer))/100 - if block.Time%(core.StateLength) > upperLimit-stateBuffer || block.Time%(core.StateLength) < lowerLimit+stateBuffer { + if header.Time%(core.StateLength) > upperLimit-stateBuffer || header.Time%(core.StateLength) < lowerLimit+stateBuffer { return -1, nil } - state := block.Time / core.StateLength + state := header.Time / core.StateLength return int64(state % core.NumberOfStates), nil } @@ -379,14 +377,8 @@ func (*FileStruct) ReadFromDisputeJsonFile(filePath string) (types.DisputeFileDa return disputeData, nil } -func (*UtilsStruct) CheckPassword(address string, password string) error { - razorPath, err := PathInterface.GetDefaultPath() - if err != nil { - log.Error("CheckPassword: Error in getting .razor path: ", err) - return err - } - keystorePath := filepath.Join(razorPath, "keystore_files") - _, err = AccountsInterface.GetPrivateKey(address, password, keystorePath) +func (*UtilsStruct) CheckPassword(account types.Account) error { + _, err := account.AccountManager.GetPrivateKey(account.Address, account.Password) if err != nil { log.Info("Kindly check your password!") log.Error("CheckPassword: Error in getting private key: ", err) @@ -394,3 +386,15 @@ func (*UtilsStruct) CheckPassword(address string, password string) error { } return nil } + +func (*UtilsStruct) AccountManagerForKeystore() (types.AccountManagerInterface, error) { + razorPath, err := PathInterface.GetDefaultPath() + if err != nil { + log.Error("GetKeystorePath: Error in getting .razor path: ", err) + return nil, err + } + keystorePath := filepath.Join(razorPath, "keystore_files") + + accountManager := accounts.NewAccountManager(keystorePath) + return accountManager, nil +} diff --git a/utils/common_test.go b/utils/common_test.go index 0e952f600..4ab53a83b 100644 --- a/utils/common_test.go +++ b/utils/common_test.go @@ -4,6 +4,7 @@ import ( "errors" "math/big" "os" + "razor/accounts" Types "razor/core/types" "razor/pkg/bindings" "razor/utils/mocks" @@ -365,7 +366,6 @@ func TestGetBufferedState(t *testing.T) { type args struct { block *types.Header - blockErr error buffer int32 stateBuffer uint64 stateBufferErr error @@ -390,18 +390,7 @@ func TestGetBufferedState(t *testing.T) { wantErr: false, }, { - name: "Test 2: When there is an error in getting block", - args: args{ - block: &types.Header{ - Number: big.NewInt(100), - }, - blockErr: errors.New("block error"), - }, - want: -1, - wantErr: true, - }, - { - name: "Test 3: When blockNumber%(core.StateLength) is greater than lowerLimit", + name: "Test 2: When blockNumber%(core.StateLength) is greater than lowerLimit", args: args{ block: &types.Header{ Time: 1080, @@ -413,7 +402,7 @@ func TestGetBufferedState(t *testing.T) { wantErr: false, }, { - name: "Test 4: When GetBufferedState() executes successfully and state we get is other than 0", + name: "Test 3: When GetBufferedState() executes successfully and state we get is other than 0", args: args{ block: &types.Header{ Time: 900, @@ -426,7 +415,7 @@ func TestGetBufferedState(t *testing.T) { wantErr: false, }, { - name: "Test 5: When there is an error in getting stateBuffer", + name: "Test 4: When there is an error in getting stateBuffer", args: args{ block: &types.Header{ Time: 100, @@ -453,9 +442,8 @@ func TestGetBufferedState(t *testing.T) { utils := StartRazor(optionsPackageStruct) utilsMock.On("GetStateBuffer", mock.AnythingOfType("*ethclient.Client")).Return(tt.args.stateBuffer, tt.args.stateBufferErr) - clientUtilsMock.On("GetLatestBlockWithRetry", mock.AnythingOfType("*ethclient.Client")).Return(tt.args.block, tt.args.blockErr) - got, err := utils.GetBufferedState(client, tt.args.buffer) + got, err := utils.GetBufferedState(client, tt.args.block, tt.args.buffer) if (err != nil) != tt.wantErr { t.Errorf("GetBufferedState() error = %v, wantErr %v", err, tt.wantErr) return @@ -1559,3 +1547,101 @@ func TestReadFromDisputeJsonFile(t *testing.T) { }) } } + +func TestUtilsStruct_CheckPassword(t *testing.T) { + type args struct { + account Types.Account + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "Test 1: When password is correct", + args: args{ + account: Types.Account{ + Address: "0x57Baf83BAD5bee0F7F44d84669A50C35c57E3576", + Password: "Test@123", + AccountManager: accounts.NewAccountManager("test_accounts"), + }, + }, + wantErr: false, + }, + { + name: "Test 2: When password is incorrect", + args: args{ + account: Types.Account{ + Address: "0x57Baf83BAD5bee0F7F44d84669A50C35c57E3576", + Password: "Test@456", + AccountManager: accounts.NewAccountManager("test_accounts"), + }, + }, + wantErr: true, + }, + { + name: "Test 3: When address or keystore path provided is not present", + args: args{ + account: Types.Account{ + Address: "0x57Baf83BAD5bee0F7F44d84669A50C35c57E3576", + Password: "Test@123", + AccountManager: accounts.NewAccountManager("test_accounts_1"), + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ut := &UtilsStruct{} + if err := ut.CheckPassword(tt.args.account); (err != nil) != tt.wantErr { + t.Errorf("CheckPassword() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestUtilsStruct_AccountManagerForKeystore(t *testing.T) { + type args struct { + path string + pathErr error + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "Test 1: When account manager for keystore is returned successfully", + args: args{ + path: "test_accounts", + }, + wantErr: false, + }, + { + name: "Test 2: When there is an error in getting path", + args: args{ + pathErr: errors.New("path error"), + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pathMock := new(mocks.PathUtils) + optionsPackageStruct := OptionsPackageStruct{ + PathInterface: pathMock, + } + + utils := StartRazor(optionsPackageStruct) + + pathMock.On("GetDefaultPath").Return(tt.args.path, tt.args.pathErr) + + _, err := utils.AccountManagerForKeystore() + if (err != nil) != tt.wantErr { + t.Errorf("AccountManagerForKeystore() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} diff --git a/utils/interface.go b/utils/interface.go index fb2a4935c..cc7ee2274 100644 --- a/utils/interface.go +++ b/utils/interface.go @@ -3,6 +3,7 @@ package utils import ( "context" "crypto/ecdsa" + "github.com/ethereum/go-ethereum/rpc" "io" "io/fs" "math/big" @@ -32,7 +33,6 @@ import ( //go:generate mockery --name ABIUtils --output ./mocks --case=underscore //go:generate mockery --name PathUtils --output ./mocks --case=underscore //go:generate mockery --name BindUtils --output ./mocks --case=underscore -//go:generate mockery --name AccountsUtils --output ./mocks --case=underscore //go:generate mockery --name BlockManagerUtils --output ./mocks --case=underscore //go:generate mockery --name AssetManagerUtils --output ./mocks --case=underscore //go:generate mockery --name VoteManagerUtils --output ./mocks --case=underscore @@ -56,7 +56,6 @@ var IOInterface IOUtils var ABIInterface ABIUtils var PathInterface PathUtils var BindInterface BindUtils -var AccountsInterface AccountsUtils var BlockManagerInterface BlockManagerUtils var StakeManagerInterface StakeManagerUtils var AssetManagerInterface AssetManagerUtils @@ -111,23 +110,23 @@ type Utils interface { GetNumCollections(client *ethclient.Client) (uint16, error) GetActiveJob(client *ethclient.Client, jobId uint16) (bindings.StructsJob, error) GetCollection(client *ethclient.Client, collectionId uint16) (bindings.StructsCollection, error) - GetActiveCollection(client *ethclient.Client, collectionId uint16) (bindings.StructsCollection, error) - Aggregate(client *ethclient.Client, previousEpoch uint32, collection bindings.StructsCollection, localCache *cache.LocalCache) (*big.Int, error) - GetDataToCommitFromJobs(jobs []bindings.StructsJob, localCache *cache.LocalCache) ([]*big.Int, []uint8, error) - GetDataToCommitFromJob(job bindings.StructsJob, localCache *cache.LocalCache) (*big.Int, error) + GetActiveCollection(collectionsCache *cache.CollectionsCache, collectionId uint16) (bindings.StructsCollection, error) + Aggregate(client *ethclient.Client, previousEpoch uint32, collection bindings.StructsCollection, commitParams *types.CommitParams) (*big.Int, error) + GetDataToCommitFromJobs(jobs []bindings.StructsJob, commitParams *types.CommitParams) ([]*big.Int, []uint8) + GetDataToCommitFromJob(job bindings.StructsJob, commitParams *types.CommitParams) (*big.Int, error) GetAssignedCollections(client *ethclient.Client, numActiveCollections uint16, seed []byte) (map[int]bool, []*big.Int, error) GetLeafIdOfACollection(client *ethclient.Client, collectionId uint16) (uint16, error) GetCollectionIdFromIndex(client *ethclient.Client, medianIndex uint16) (uint16, error) GetCollectionIdFromLeafId(client *ethclient.Client, leafId uint16) (uint16, error) GetNumActiveCollections(client *ethclient.Client) (uint16, error) - GetAggregatedDataOfCollection(client *ethclient.Client, collectionId uint16, epoch uint32, localCache *cache.LocalCache) (*big.Int, error) + GetAggregatedDataOfCollection(client *ethclient.Client, collectionId uint16, epoch uint32, commitParams *types.CommitParams) (*big.Int, error) GetJobs(client *ethclient.Client) ([]bindings.StructsJob, error) GetAllCollections(client *ethclient.Client) ([]bindings.StructsCollection, error) GetActiveCollectionIds(client *ethclient.Client) ([]uint16, error) - HandleOfficialJobsFromJSONFile(client *ethclient.Client, collection bindings.StructsCollection, dataString string) ([]bindings.StructsJob, []uint16) + HandleOfficialJobsFromJSONFile(client *ethclient.Client, collection bindings.StructsCollection, dataString string, commitParams *types.CommitParams) ([]bindings.StructsJob, []uint16) ConnectToClient(provider string) *ethclient.Client FetchBalance(client *ethclient.Client, accountAddress string) (*big.Int, error) - GetBufferedState(client *ethclient.Client, buffer int32) (int64, error) + GetBufferedState(client *ethclient.Client, header *Types.Header, buffer int32) (int64, error) WaitForBlockCompletion(client *ethclient.Client, hashToRead string) error CheckEthBalanceIsZero(client *ethclient.Client, address string) AssignStakerId(flagSet *pflag.FlagSet, client *ethclient.Client, address string) (uint32, error) @@ -157,7 +156,8 @@ type Utils interface { GetRogueRandomValue(value int) *big.Int GetStakedTokenManagerWithOpts(client *ethclient.Client, tokenAddress common.Address) (*bindings.StakedToken, bind.CallOpts) GetStakerSRZRBalance(client *ethclient.Client, staker bindings.StructsStaker) (*big.Int, error) - CheckPassword(address string, password string) error + CheckPassword(account types.Account) error + AccountManagerForKeystore() (types.AccountManagerInterface, error) } type EthClientUtils interface { @@ -178,6 +178,9 @@ type ClientUtils interface { FilterLogsWithRetry(client *ethclient.Client, query ethereum.FilterQuery) ([]Types.Log, error) BalanceAtWithRetry(client *ethclient.Client, account common.Address) (*big.Int, error) GetNonceAtWithRetry(client *ethclient.Client, accountAddress common.Address) (uint64, error) + PerformBatchCall(client *ethclient.Client, calls []rpc.BatchElem) error + CreateBatchCalls(contractABI *abi.ABI, contractAddress, methodName string, args [][]interface{}) ([]rpc.BatchElem, error) + BatchCall(client *ethclient.Client, contractABI *abi.ABI, contractAddress, methodName string, args [][]interface{}) ([][]interface{}, error) } type TimeUtils interface { @@ -218,10 +221,6 @@ type BindUtils interface { NewKeyedTransactorWithChainID(key *ecdsa.PrivateKey, chainID *big.Int) (*bind.TransactOpts, error) } -type AccountsUtils interface { - GetPrivateKey(address string, password string, keystorePath string) (*ecdsa.PrivateKey, error) -} - type BlockManagerUtils interface { GetNumProposedBlocks(client *ethclient.Client, epoch uint32) (uint8, error) GetProposedBlock(client *ethclient.Client, epoch uint32, proposedBlock uint32) (bindings.StructsBlock, error) @@ -322,7 +321,6 @@ type IOStruct struct{} type ABIStruct struct{} type PathStruct struct{} type BindStruct struct{} -type AccountsStruct struct{} type BlockManagerStruct struct{} type StakeManagerStruct struct{} type AssetManagerStruct struct{} @@ -347,7 +345,6 @@ type OptionsPackageStruct struct { ABIInterface ABIUtils PathInterface PathUtils BindInterface BindUtils - AccountsInterface AccountsUtils BlockManagerInterface BlockManagerUtils StakeManagerInterface StakeManagerUtils AssetManagerInterface AssetManagerUtils diff --git a/utils/math.go b/utils/math.go index 976ec9202..5312b64a6 100644 --- a/utils/math.go +++ b/utils/math.go @@ -3,6 +3,7 @@ package utils import ( "crypto/rand" "errors" + "github.com/ethereum/go-ethereum/common" "math" "math/big" mathRand "math/rand" @@ -294,3 +295,11 @@ func isHexArrayPattern(s string) bool { re := regexp.MustCompile(pattern) return re.MatchString(s) } + +func ConvertHashToUint16(hash common.Hash) uint16 { + // Convert the hash to a big integer to handle the numeric value + bigIntValue := hash.Big() + + // Convert the big integer to uint64 first (safe for down casting to uint16) and then downcast to uint16 + return uint16(bigIntValue.Uint64()) +} diff --git a/utils/math_test.go b/utils/math_test.go index 9106c9fd7..0772dad53 100644 --- a/utils/math_test.go +++ b/utils/math_test.go @@ -1,6 +1,7 @@ package utils import ( + "github.com/ethereum/go-ethereum/common" "math/big" "razor/utils/mocks" "reflect" @@ -1297,3 +1298,46 @@ func Test_isHexArrayPattern(t *testing.T) { }) } } + +func TestConvertHashToUint16(t *testing.T) { + tests := []struct { + name string + hash common.Hash + expected uint16 + }{ + { + name: "ZeroHash", + hash: common.Hash{}, + expected: 0, + }, + { + name: "SmallNumber", + hash: common.BigToHash(big.NewInt(42)), + expected: 42, + }, + { + name: "MaxUint16", + hash: common.BigToHash(big.NewInt(65535)), + expected: 65535, + }, + { + name: "OverflowUint16", + hash: common.BigToHash(big.NewInt(65536)), + expected: 0, // 65536 % 65536 == 0 + }, + { + name: "LargeNumber", + hash: common.BigToHash(big.NewInt(123456789)), + expected: 52501, // 123456789 % 65536 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ConvertHashToUint16(tt.hash) + if result != tt.expected { + t.Errorf("ConvertHashToUint16(%v) = %v, expected %v", tt.hash, result, tt.expected) + } + }) + } +} diff --git a/utils/mocks/client_utils.go b/utils/mocks/client_utils.go index 703c4b5e2..73cdfe6ac 100644 --- a/utils/mocks/client_utils.go +++ b/utils/mocks/client_utils.go @@ -1,19 +1,24 @@ -// Code generated by mockery v2.14.0. DO NOT EDIT. +// Code generated by mockery v2.30.1. DO NOT EDIT. package mocks import ( - context "context" big "math/big" + abi "github.com/ethereum/go-ethereum/accounts/abi" + common "github.com/ethereum/go-ethereum/common" + context "context" + ethclient "github.com/ethereum/go-ethereum/ethclient" ethereum "github.com/ethereum/go-ethereum" mock "github.com/stretchr/testify/mock" + rpc "github.com/ethereum/go-ethereum/rpc" + types "github.com/ethereum/go-ethereum/core/types" ) @@ -27,6 +32,10 @@ func (_m *ClientUtils) BalanceAt(client *ethclient.Client, ctx context.Context, ret := _m.Called(client, ctx, account, blockNumber) var r0 *big.Int + var r1 error + if rf, ok := ret.Get(0).(func(*ethclient.Client, context.Context, common.Address, *big.Int) (*big.Int, error)); ok { + return rf(client, ctx, account, blockNumber) + } if rf, ok := ret.Get(0).(func(*ethclient.Client, context.Context, common.Address, *big.Int) *big.Int); ok { r0 = rf(client, ctx, account, blockNumber) } else { @@ -35,7 +44,6 @@ func (_m *ClientUtils) BalanceAt(client *ethclient.Client, ctx context.Context, } } - var r1 error if rf, ok := ret.Get(1).(func(*ethclient.Client, context.Context, common.Address, *big.Int) error); ok { r1 = rf(client, ctx, account, blockNumber) } else { @@ -50,6 +58,10 @@ func (_m *ClientUtils) BalanceAtWithRetry(client *ethclient.Client, account comm ret := _m.Called(client, account) var r0 *big.Int + var r1 error + if rf, ok := ret.Get(0).(func(*ethclient.Client, common.Address) (*big.Int, error)); ok { + return rf(client, account) + } if rf, ok := ret.Get(0).(func(*ethclient.Client, common.Address) *big.Int); ok { r0 = rf(client, account) } else { @@ -58,7 +70,6 @@ func (_m *ClientUtils) BalanceAtWithRetry(client *ethclient.Client, account comm } } - var r1 error if rf, ok := ret.Get(1).(func(*ethclient.Client, common.Address) error); ok { r1 = rf(client, account) } else { @@ -68,18 +79,73 @@ func (_m *ClientUtils) BalanceAtWithRetry(client *ethclient.Client, account comm return r0, r1 } +// BatchCall provides a mock function with given fields: client, contractABI, contractAddress, methodName, args +func (_m *ClientUtils) BatchCall(client *ethclient.Client, contractABI *abi.ABI, contractAddress string, methodName string, args [][]interface{}) ([][]interface{}, error) { + ret := _m.Called(client, contractABI, contractAddress, methodName, args) + + var r0 [][]interface{} + var r1 error + if rf, ok := ret.Get(0).(func(*ethclient.Client, *abi.ABI, string, string, [][]interface{}) ([][]interface{}, error)); ok { + return rf(client, contractABI, contractAddress, methodName, args) + } + if rf, ok := ret.Get(0).(func(*ethclient.Client, *abi.ABI, string, string, [][]interface{}) [][]interface{}); ok { + r0 = rf(client, contractABI, contractAddress, methodName, args) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([][]interface{}) + } + } + + if rf, ok := ret.Get(1).(func(*ethclient.Client, *abi.ABI, string, string, [][]interface{}) error); ok { + r1 = rf(client, contractABI, contractAddress, methodName, args) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// CreateBatchCalls provides a mock function with given fields: contractABI, contractAddress, methodName, args +func (_m *ClientUtils) CreateBatchCalls(contractABI *abi.ABI, contractAddress string, methodName string, args [][]interface{}) ([]rpc.BatchElem, error) { + ret := _m.Called(contractABI, contractAddress, methodName, args) + + var r0 []rpc.BatchElem + var r1 error + if rf, ok := ret.Get(0).(func(*abi.ABI, string, string, [][]interface{}) ([]rpc.BatchElem, error)); ok { + return rf(contractABI, contractAddress, methodName, args) + } + if rf, ok := ret.Get(0).(func(*abi.ABI, string, string, [][]interface{}) []rpc.BatchElem); ok { + r0 = rf(contractABI, contractAddress, methodName, args) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]rpc.BatchElem) + } + } + + if rf, ok := ret.Get(1).(func(*abi.ABI, string, string, [][]interface{}) error); ok { + r1 = rf(contractABI, contractAddress, methodName, args) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // EstimateGas provides a mock function with given fields: client, ctx, msg func (_m *ClientUtils) EstimateGas(client *ethclient.Client, ctx context.Context, msg ethereum.CallMsg) (uint64, error) { ret := _m.Called(client, ctx, msg) var r0 uint64 + var r1 error + if rf, ok := ret.Get(0).(func(*ethclient.Client, context.Context, ethereum.CallMsg) (uint64, error)); ok { + return rf(client, ctx, msg) + } if rf, ok := ret.Get(0).(func(*ethclient.Client, context.Context, ethereum.CallMsg) uint64); ok { r0 = rf(client, ctx, msg) } else { r0 = ret.Get(0).(uint64) } - var r1 error if rf, ok := ret.Get(1).(func(*ethclient.Client, context.Context, ethereum.CallMsg) error); ok { r1 = rf(client, ctx, msg) } else { @@ -94,13 +160,16 @@ func (_m *ClientUtils) EstimateGasWithRetry(client *ethclient.Client, message et ret := _m.Called(client, message) var r0 uint64 + var r1 error + if rf, ok := ret.Get(0).(func(*ethclient.Client, ethereum.CallMsg) (uint64, error)); ok { + return rf(client, message) + } if rf, ok := ret.Get(0).(func(*ethclient.Client, ethereum.CallMsg) uint64); ok { r0 = rf(client, message) } else { r0 = ret.Get(0).(uint64) } - var r1 error if rf, ok := ret.Get(1).(func(*ethclient.Client, ethereum.CallMsg) error); ok { r1 = rf(client, message) } else { @@ -115,6 +184,10 @@ func (_m *ClientUtils) FilterLogs(client *ethclient.Client, ctx context.Context, ret := _m.Called(client, ctx, q) var r0 []types.Log + var r1 error + if rf, ok := ret.Get(0).(func(*ethclient.Client, context.Context, ethereum.FilterQuery) ([]types.Log, error)); ok { + return rf(client, ctx, q) + } if rf, ok := ret.Get(0).(func(*ethclient.Client, context.Context, ethereum.FilterQuery) []types.Log); ok { r0 = rf(client, ctx, q) } else { @@ -123,7 +196,6 @@ func (_m *ClientUtils) FilterLogs(client *ethclient.Client, ctx context.Context, } } - var r1 error if rf, ok := ret.Get(1).(func(*ethclient.Client, context.Context, ethereum.FilterQuery) error); ok { r1 = rf(client, ctx, q) } else { @@ -138,6 +210,10 @@ func (_m *ClientUtils) FilterLogsWithRetry(client *ethclient.Client, query ether ret := _m.Called(client, query) var r0 []types.Log + var r1 error + if rf, ok := ret.Get(0).(func(*ethclient.Client, ethereum.FilterQuery) ([]types.Log, error)); ok { + return rf(client, query) + } if rf, ok := ret.Get(0).(func(*ethclient.Client, ethereum.FilterQuery) []types.Log); ok { r0 = rf(client, query) } else { @@ -146,7 +222,6 @@ func (_m *ClientUtils) FilterLogsWithRetry(client *ethclient.Client, query ether } } - var r1 error if rf, ok := ret.Get(1).(func(*ethclient.Client, ethereum.FilterQuery) error); ok { r1 = rf(client, query) } else { @@ -161,6 +236,10 @@ func (_m *ClientUtils) GetLatestBlockWithRetry(client *ethclient.Client) (*types ret := _m.Called(client) var r0 *types.Header + var r1 error + if rf, ok := ret.Get(0).(func(*ethclient.Client) (*types.Header, error)); ok { + return rf(client) + } if rf, ok := ret.Get(0).(func(*ethclient.Client) *types.Header); ok { r0 = rf(client) } else { @@ -169,7 +248,6 @@ func (_m *ClientUtils) GetLatestBlockWithRetry(client *ethclient.Client) (*types } } - var r1 error if rf, ok := ret.Get(1).(func(*ethclient.Client) error); ok { r1 = rf(client) } else { @@ -184,13 +262,16 @@ func (_m *ClientUtils) GetNonceAtWithRetry(client *ethclient.Client, accountAddr ret := _m.Called(client, accountAddress) var r0 uint64 + var r1 error + if rf, ok := ret.Get(0).(func(*ethclient.Client, common.Address) (uint64, error)); ok { + return rf(client, accountAddress) + } if rf, ok := ret.Get(0).(func(*ethclient.Client, common.Address) uint64); ok { r0 = rf(client, accountAddress) } else { r0 = ret.Get(0).(uint64) } - var r1 error if rf, ok := ret.Get(1).(func(*ethclient.Client, common.Address) error); ok { r1 = rf(client, accountAddress) } else { @@ -205,6 +286,10 @@ func (_m *ClientUtils) HeaderByNumber(client *ethclient.Client, ctx context.Cont ret := _m.Called(client, ctx, number) var r0 *types.Header + var r1 error + if rf, ok := ret.Get(0).(func(*ethclient.Client, context.Context, *big.Int) (*types.Header, error)); ok { + return rf(client, ctx, number) + } if rf, ok := ret.Get(0).(func(*ethclient.Client, context.Context, *big.Int) *types.Header); ok { r0 = rf(client, ctx, number) } else { @@ -213,7 +298,6 @@ func (_m *ClientUtils) HeaderByNumber(client *ethclient.Client, ctx context.Cont } } - var r1 error if rf, ok := ret.Get(1).(func(*ethclient.Client, context.Context, *big.Int) error); ok { r1 = rf(client, ctx, number) } else { @@ -228,13 +312,16 @@ func (_m *ClientUtils) NonceAt(client *ethclient.Client, ctx context.Context, ac ret := _m.Called(client, ctx, account) var r0 uint64 + var r1 error + if rf, ok := ret.Get(0).(func(*ethclient.Client, context.Context, common.Address) (uint64, error)); ok { + return rf(client, ctx, account) + } if rf, ok := ret.Get(0).(func(*ethclient.Client, context.Context, common.Address) uint64); ok { r0 = rf(client, ctx, account) } else { r0 = ret.Get(0).(uint64) } - var r1 error if rf, ok := ret.Get(1).(func(*ethclient.Client, context.Context, common.Address) error); ok { r1 = rf(client, ctx, account) } else { @@ -244,11 +331,29 @@ func (_m *ClientUtils) NonceAt(client *ethclient.Client, ctx context.Context, ac return r0, r1 } +// PerformBatchCall provides a mock function with given fields: client, calls +func (_m *ClientUtils) PerformBatchCall(client *ethclient.Client, calls []rpc.BatchElem) error { + ret := _m.Called(client, calls) + + var r0 error + if rf, ok := ret.Get(0).(func(*ethclient.Client, []rpc.BatchElem) error); ok { + r0 = rf(client, calls) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // SuggestGasPrice provides a mock function with given fields: client, ctx func (_m *ClientUtils) SuggestGasPrice(client *ethclient.Client, ctx context.Context) (*big.Int, error) { ret := _m.Called(client, ctx) var r0 *big.Int + var r1 error + if rf, ok := ret.Get(0).(func(*ethclient.Client, context.Context) (*big.Int, error)); ok { + return rf(client, ctx) + } if rf, ok := ret.Get(0).(func(*ethclient.Client, context.Context) *big.Int); ok { r0 = rf(client, ctx) } else { @@ -257,7 +362,6 @@ func (_m *ClientUtils) SuggestGasPrice(client *ethclient.Client, ctx context.Con } } - var r1 error if rf, ok := ret.Get(1).(func(*ethclient.Client, context.Context) error); ok { r1 = rf(client, ctx) } else { @@ -272,6 +376,10 @@ func (_m *ClientUtils) SuggestGasPriceWithRetry(client *ethclient.Client) (*big. ret := _m.Called(client) var r0 *big.Int + var r1 error + if rf, ok := ret.Get(0).(func(*ethclient.Client) (*big.Int, error)); ok { + return rf(client) + } if rf, ok := ret.Get(0).(func(*ethclient.Client) *big.Int); ok { r0 = rf(client) } else { @@ -280,7 +388,6 @@ func (_m *ClientUtils) SuggestGasPriceWithRetry(client *ethclient.Client) (*big. } } - var r1 error if rf, ok := ret.Get(1).(func(*ethclient.Client) error); ok { r1 = rf(client) } else { @@ -295,6 +402,10 @@ func (_m *ClientUtils) TransactionReceipt(client *ethclient.Client, ctx context. ret := _m.Called(client, ctx, txHash) var r0 *types.Receipt + var r1 error + if rf, ok := ret.Get(0).(func(*ethclient.Client, context.Context, common.Hash) (*types.Receipt, error)); ok { + return rf(client, ctx, txHash) + } if rf, ok := ret.Get(0).(func(*ethclient.Client, context.Context, common.Hash) *types.Receipt); ok { r0 = rf(client, ctx, txHash) } else { @@ -303,7 +414,6 @@ func (_m *ClientUtils) TransactionReceipt(client *ethclient.Client, ctx context. } } - var r1 error if rf, ok := ret.Get(1).(func(*ethclient.Client, context.Context, common.Hash) error); ok { r1 = rf(client, ctx, txHash) } else { @@ -313,13 +423,12 @@ func (_m *ClientUtils) TransactionReceipt(client *ethclient.Client, ctx context. return r0, r1 } -type mockConstructorTestingTNewClientUtils interface { +// NewClientUtils creates a new instance of ClientUtils. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewClientUtils(t interface { mock.TestingT Cleanup(func()) -} - -// NewClientUtils creates a new instance of ClientUtils. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewClientUtils(t mockConstructorTestingTNewClientUtils) *ClientUtils { +}) *ClientUtils { mock := &ClientUtils{} mock.Mock.Test(t) diff --git a/utils/mocks/utils.go b/utils/mocks/utils.go index 378936dc0..869a2fbcb 100644 --- a/utils/mocks/utils.go +++ b/utils/mocks/utils.go @@ -12,6 +12,8 @@ import ( common "github.com/ethereum/go-ethereum/common" + coretypes "github.com/ethereum/go-ethereum/core/types" + ethclient "github.com/ethereum/go-ethereum/ethclient" mock "github.com/stretchr/testify/mock" @@ -26,6 +28,32 @@ type Utils struct { mock.Mock } +// AccountManagerForKeystore provides a mock function with given fields: +func (_m *Utils) AccountManagerForKeystore() (types.AccountManagerInterface, error) { + ret := _m.Called() + + var r0 types.AccountManagerInterface + var r1 error + if rf, ok := ret.Get(0).(func() (types.AccountManagerInterface, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() types.AccountManagerInterface); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(types.AccountManagerInterface) + } + } + + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // AddJobToJSON provides a mock function with given fields: fileName, job func (_m *Utils) AddJobToJSON(fileName string, job *types.StructsJob) error { ret := _m.Called(fileName, job) @@ -40,25 +68,25 @@ func (_m *Utils) AddJobToJSON(fileName string, job *types.StructsJob) error { return r0 } -// Aggregate provides a mock function with given fields: client, previousEpoch, collection, localCache -func (_m *Utils) Aggregate(client *ethclient.Client, previousEpoch uint32, collection bindings.StructsCollection, localCache *cache.LocalCache) (*big.Int, error) { - ret := _m.Called(client, previousEpoch, collection, localCache) +// Aggregate provides a mock function with given fields: client, previousEpoch, collection, commitParams +func (_m *Utils) Aggregate(client *ethclient.Client, previousEpoch uint32, collection bindings.StructsCollection, commitParams *types.CommitParams) (*big.Int, error) { + ret := _m.Called(client, previousEpoch, collection, commitParams) var r0 *big.Int var r1 error - if rf, ok := ret.Get(0).(func(*ethclient.Client, uint32, bindings.StructsCollection, *cache.LocalCache) (*big.Int, error)); ok { - return rf(client, previousEpoch, collection, localCache) + if rf, ok := ret.Get(0).(func(*ethclient.Client, uint32, bindings.StructsCollection, *types.CommitParams) (*big.Int, error)); ok { + return rf(client, previousEpoch, collection, commitParams) } - if rf, ok := ret.Get(0).(func(*ethclient.Client, uint32, bindings.StructsCollection, *cache.LocalCache) *big.Int); ok { - r0 = rf(client, previousEpoch, collection, localCache) + if rf, ok := ret.Get(0).(func(*ethclient.Client, uint32, bindings.StructsCollection, *types.CommitParams) *big.Int); ok { + r0 = rf(client, previousEpoch, collection, commitParams) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*big.Int) } } - if rf, ok := ret.Get(1).(func(*ethclient.Client, uint32, bindings.StructsCollection, *cache.LocalCache) error); ok { - r1 = rf(client, previousEpoch, collection, localCache) + if rf, ok := ret.Get(1).(func(*ethclient.Client, uint32, bindings.StructsCollection, *types.CommitParams) error); ok { + r1 = rf(client, previousEpoch, collection, commitParams) } else { r1 = ret.Error(1) } @@ -155,13 +183,13 @@ func (_m *Utils) CheckEthBalanceIsZero(client *ethclient.Client, address string) _m.Called(client, address) } -// CheckPassword provides a mock function with given fields: address, password -func (_m *Utils) CheckPassword(address string, password string) error { - ret := _m.Called(address, password) +// CheckPassword provides a mock function with given fields: account +func (_m *Utils) CheckPassword(account types.Account) error { + ret := _m.Called(account) var r0 error - if rf, ok := ret.Get(0).(func(string, string) error); ok { - r0 = rf(address, password) + if rf, ok := ret.Get(0).(func(types.Account) error); ok { + r0 = rf(account) } else { r0 = ret.Error(0) } @@ -199,32 +227,6 @@ func (_m *Utils) ConnectToClient(provider string) *ethclient.Client { return r0 } -// ConvertToNumber provides a mock function with given fields: num -func (_m *Utils) ConvertToNumber(num interface{}) (*big.Float, error) { - ret := _m.Called(num) - - var r0 *big.Float - var r1 error - if rf, ok := ret.Get(0).(func(interface{}) (*big.Float, error)); ok { - return rf(num) - } - if rf, ok := ret.Get(0).(func(interface{}) *big.Float); ok { - r0 = rf(num) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*big.Float) - } - } - - if rf, ok := ret.Get(1).(func(interface{}) error); ok { - r1 = rf(num) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // DeleteJobFromJSON provides a mock function with given fields: fileName, jobId func (_m *Utils) DeleteJobFromJSON(fileName string, jobId string) error { ret := _m.Called(fileName, jobId) @@ -317,23 +319,23 @@ func (_m *Utils) FetchPreviousValue(client *ethclient.Client, epoch uint32, asse return r0, r1 } -// GetActiveCollection provides a mock function with given fields: client, collectionId -func (_m *Utils) GetActiveCollection(client *ethclient.Client, collectionId uint16) (bindings.StructsCollection, error) { - ret := _m.Called(client, collectionId) +// GetActiveCollection provides a mock function with given fields: collectionsCache, collectionId +func (_m *Utils) GetActiveCollection(collectionsCache *cache.CollectionsCache, collectionId uint16) (bindings.StructsCollection, error) { + ret := _m.Called(collectionsCache, collectionId) var r0 bindings.StructsCollection var r1 error - if rf, ok := ret.Get(0).(func(*ethclient.Client, uint16) (bindings.StructsCollection, error)); ok { - return rf(client, collectionId) + if rf, ok := ret.Get(0).(func(*cache.CollectionsCache, uint16) (bindings.StructsCollection, error)); ok { + return rf(collectionsCache, collectionId) } - if rf, ok := ret.Get(0).(func(*ethclient.Client, uint16) bindings.StructsCollection); ok { - r0 = rf(client, collectionId) + if rf, ok := ret.Get(0).(func(*cache.CollectionsCache, uint16) bindings.StructsCollection); ok { + r0 = rf(collectionsCache, collectionId) } else { r0 = ret.Get(0).(bindings.StructsCollection) } - if rf, ok := ret.Get(1).(func(*ethclient.Client, uint16) error); ok { - r1 = rf(client, collectionId) + if rf, ok := ret.Get(1).(func(*cache.CollectionsCache, uint16) error); ok { + r1 = rf(collectionsCache, collectionId) } else { r1 = ret.Error(1) } @@ -391,25 +393,25 @@ func (_m *Utils) GetActiveJob(client *ethclient.Client, jobId uint16) (bindings. return r0, r1 } -// GetAggregatedDataOfCollection provides a mock function with given fields: client, collectionId, epoch, localCache -func (_m *Utils) GetAggregatedDataOfCollection(client *ethclient.Client, collectionId uint16, epoch uint32, localCache *cache.LocalCache) (*big.Int, error) { - ret := _m.Called(client, collectionId, epoch, localCache) +// GetAggregatedDataOfCollection provides a mock function with given fields: client, collectionId, epoch, commitParams +func (_m *Utils) GetAggregatedDataOfCollection(client *ethclient.Client, collectionId uint16, epoch uint32, commitParams *types.CommitParams) (*big.Int, error) { + ret := _m.Called(client, collectionId, epoch, commitParams) var r0 *big.Int var r1 error - if rf, ok := ret.Get(0).(func(*ethclient.Client, uint16, uint32, *cache.LocalCache) (*big.Int, error)); ok { - return rf(client, collectionId, epoch, localCache) + if rf, ok := ret.Get(0).(func(*ethclient.Client, uint16, uint32, *types.CommitParams) (*big.Int, error)); ok { + return rf(client, collectionId, epoch, commitParams) } - if rf, ok := ret.Get(0).(func(*ethclient.Client, uint16, uint32, *cache.LocalCache) *big.Int); ok { - r0 = rf(client, collectionId, epoch, localCache) + if rf, ok := ret.Get(0).(func(*ethclient.Client, uint16, uint32, *types.CommitParams) *big.Int); ok { + r0 = rf(client, collectionId, epoch, commitParams) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*big.Int) } } - if rf, ok := ret.Get(1).(func(*ethclient.Client, uint16, uint32, *cache.LocalCache) error); ok { - r1 = rf(client, collectionId, epoch, localCache) + if rf, ok := ret.Get(1).(func(*ethclient.Client, uint16, uint32, *types.CommitParams) error); ok { + r1 = rf(client, collectionId, epoch, commitParams) } else { r1 = ret.Error(1) } @@ -568,23 +570,23 @@ func (_m *Utils) GetBlockManagerWithOpts(client *ethclient.Client) (*bindings.Bl return r0, r1 } -// GetBufferedState provides a mock function with given fields: client, buffer -func (_m *Utils) GetBufferedState(client *ethclient.Client, buffer int32) (int64, error) { - ret := _m.Called(client, buffer) +// GetBufferedState provides a mock function with given fields: client, header, buffer +func (_m *Utils) GetBufferedState(client *ethclient.Client, header *coretypes.Header, buffer int32) (int64, error) { + ret := _m.Called(client, header, buffer) var r0 int64 var r1 error - if rf, ok := ret.Get(0).(func(*ethclient.Client, int32) (int64, error)); ok { - return rf(client, buffer) + if rf, ok := ret.Get(0).(func(*ethclient.Client, *coretypes.Header, int32) (int64, error)); ok { + return rf(client, header, buffer) } - if rf, ok := ret.Get(0).(func(*ethclient.Client, int32) int64); ok { - r0 = rf(client, buffer) + if rf, ok := ret.Get(0).(func(*ethclient.Client, *coretypes.Header, int32) int64); ok { + r0 = rf(client, header, buffer) } else { r0 = ret.Get(0).(int64) } - if rf, ok := ret.Get(1).(func(*ethclient.Client, int32) error); ok { - r1 = rf(client, buffer) + if rf, ok := ret.Get(1).(func(*ethclient.Client, *coretypes.Header, int32) error); ok { + r1 = rf(client, header, buffer) } else { r1 = ret.Error(1) } @@ -730,95 +732,25 @@ func (_m *Utils) GetCommitment(client *ethclient.Client, address string) (types. return r0, r1 } -// GetDataFromAPI provides a mock function with given fields: urlStruct, localCache -func (_m *Utils) GetDataFromAPI(urlStruct types.DataSourceURL, localCache *cache.LocalCache) ([]byte, error) { - ret := _m.Called(urlStruct, localCache) - - var r0 []byte - if rf, ok := ret.Get(0).(func(types.DataSourceURL, *cache.LocalCache) []byte); ok { - r0 = rf(urlStruct, localCache) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]byte) - } - } - - var r1 error - if rf, ok := ret.Get(1).(func(types.DataSourceURL, *cache.LocalCache) error); ok { - r1 = rf(urlStruct, localCache) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// GetDataFromJSON provides a mock function with given fields: jsonObject, selector -func (_m *Utils) GetDataFromJSON(jsonObject map[string]interface{}, selector string) (interface{}, error) { - ret := _m.Called(jsonObject, selector) - - var r0 interface{} - var r1 error - if rf, ok := ret.Get(0).(func(map[string]interface{}, string) (interface{}, error)); ok { - return rf(jsonObject, selector) - } - if rf, ok := ret.Get(0).(func(map[string]interface{}, string) interface{}); ok { - r0 = rf(jsonObject, selector) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(interface{}) - } - } - - if rf, ok := ret.Get(1).(func(map[string]interface{}, string) error); ok { - r1 = rf(jsonObject, selector) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// GetDataFromXHTML provides a mock function with given fields: urlStruct, selector -func (_m *Utils) GetDataFromXHTML(urlStruct types.DataSourceURL, selector string) (string, error) { - ret := _m.Called(urlStruct, selector) - - var r0 string - if rf, ok := ret.Get(0).(func(types.DataSourceURL, string) string); ok { - r0 = rf(urlStruct, selector) - } else { - r0 = ret.Get(0).(string) - } - - var r1 error - if rf, ok := ret.Get(1).(func(types.DataSourceURL, string) error); ok { - r1 = rf(urlStruct, selector) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// GetDataToCommitFromJob provides a mock function with given fields: job, localCache -func (_m *Utils) GetDataToCommitFromJob(job bindings.StructsJob, localCache *cache.LocalCache) (*big.Int, error) { - ret := _m.Called(job, localCache) +// GetDataToCommitFromJob provides a mock function with given fields: job, commitParams +func (_m *Utils) GetDataToCommitFromJob(job bindings.StructsJob, commitParams *types.CommitParams) (*big.Int, error) { + ret := _m.Called(job, commitParams) var r0 *big.Int var r1 error - if rf, ok := ret.Get(0).(func(bindings.StructsJob, *cache.LocalCache) (*big.Int, error)); ok { - return rf(job, localCache) + if rf, ok := ret.Get(0).(func(bindings.StructsJob, *types.CommitParams) (*big.Int, error)); ok { + return rf(job, commitParams) } - if rf, ok := ret.Get(0).(func(bindings.StructsJob, *cache.LocalCache) *big.Int); ok { - r0 = rf(job, localCache) + if rf, ok := ret.Get(0).(func(bindings.StructsJob, *types.CommitParams) *big.Int); ok { + r0 = rf(job, commitParams) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*big.Int) } } - if rf, ok := ret.Get(1).(func(bindings.StructsJob, *cache.LocalCache) error); ok { - r1 = rf(job, localCache) + if rf, ok := ret.Get(1).(func(bindings.StructsJob, *types.CommitParams) error); ok { + r1 = rf(job, commitParams) } else { r1 = ret.Error(1) } @@ -826,39 +758,32 @@ func (_m *Utils) GetDataToCommitFromJob(job bindings.StructsJob, localCache *cac return r0, r1 } -// GetDataToCommitFromJobs provides a mock function with given fields: jobs, localCache -func (_m *Utils) GetDataToCommitFromJobs(jobs []bindings.StructsJob, localCache *cache.LocalCache) ([]*big.Int, []uint8, error) { - ret := _m.Called(jobs, localCache) +// GetDataToCommitFromJobs provides a mock function with given fields: jobs, commitParams +func (_m *Utils) GetDataToCommitFromJobs(jobs []bindings.StructsJob, commitParams *types.CommitParams) ([]*big.Int, []uint8) { + ret := _m.Called(jobs, commitParams) var r0 []*big.Int var r1 []uint8 - var r2 error - if rf, ok := ret.Get(0).(func([]bindings.StructsJob, *cache.LocalCache) ([]*big.Int, []uint8, error)); ok { - return rf(jobs, localCache) + if rf, ok := ret.Get(0).(func([]bindings.StructsJob, *types.CommitParams) ([]*big.Int, []uint8)); ok { + return rf(jobs, commitParams) } - if rf, ok := ret.Get(0).(func([]bindings.StructsJob, *cache.LocalCache) []*big.Int); ok { - r0 = rf(jobs, localCache) + if rf, ok := ret.Get(0).(func([]bindings.StructsJob, *types.CommitParams) []*big.Int); ok { + r0 = rf(jobs, commitParams) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*big.Int) } } - if rf, ok := ret.Get(1).(func([]bindings.StructsJob, *cache.LocalCache) []uint8); ok { - r1 = rf(jobs, localCache) + if rf, ok := ret.Get(1).(func([]bindings.StructsJob, *types.CommitParams) []uint8); ok { + r1 = rf(jobs, commitParams) } else { if ret.Get(1) != nil { r1 = ret.Get(1).([]uint8) } } - if rf, ok := ret.Get(2).(func([]bindings.StructsJob, *cache.LocalCache) error); ok { - r2 = rf(jobs, localCache) - } else { - r2 = ret.Error(2) - } - - return r0, r1, r2 + return r0, r1 } // GetEpoch provides a mock function with given fields: client @@ -1813,25 +1738,25 @@ func (_m *Utils) GetWithdrawInitiationPeriod(client *ethclient.Client) (uint16, return r0, r1 } -// HandleOfficialJobsFromJSONFile provides a mock function with given fields: client, collection, dataString -func (_m *Utils) HandleOfficialJobsFromJSONFile(client *ethclient.Client, collection bindings.StructsCollection, dataString string) ([]bindings.StructsJob, []uint16) { - ret := _m.Called(client, collection, dataString) +// HandleOfficialJobsFromJSONFile provides a mock function with given fields: client, collection, dataString, commitParams +func (_m *Utils) HandleOfficialJobsFromJSONFile(client *ethclient.Client, collection bindings.StructsCollection, dataString string, commitParams *types.CommitParams) ([]bindings.StructsJob, []uint16) { + ret := _m.Called(client, collection, dataString, commitParams) var r0 []bindings.StructsJob var r1 []uint16 - if rf, ok := ret.Get(0).(func(*ethclient.Client, bindings.StructsCollection, string) ([]bindings.StructsJob, []uint16)); ok { - return rf(client, collection, dataString) + if rf, ok := ret.Get(0).(func(*ethclient.Client, bindings.StructsCollection, string, *types.CommitParams) ([]bindings.StructsJob, []uint16)); ok { + return rf(client, collection, dataString, commitParams) } - if rf, ok := ret.Get(0).(func(*ethclient.Client, bindings.StructsCollection, string) []bindings.StructsJob); ok { - r0 = rf(client, collection, dataString) + if rf, ok := ret.Get(0).(func(*ethclient.Client, bindings.StructsCollection, string, *types.CommitParams) []bindings.StructsJob); ok { + r0 = rf(client, collection, dataString, commitParams) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]bindings.StructsJob) } } - if rf, ok := ret.Get(1).(func(*ethclient.Client, bindings.StructsCollection, string) []uint16); ok { - r1 = rf(client, collection, dataString) + if rf, ok := ret.Get(1).(func(*ethclient.Client, bindings.StructsCollection, string, *types.CommitParams) []uint16); ok { + r1 = rf(client, collection, dataString, commitParams) } else { if ret.Get(1) != nil { r1 = ret.Get(1).([]uint16) diff --git a/utils/options.go b/utils/options.go index b64821544..21062e0cd 100644 --- a/utils/options.go +++ b/utils/options.go @@ -3,7 +3,6 @@ package utils import ( "context" "errors" - "path/filepath" "razor/core/types" "strings" @@ -28,14 +27,14 @@ func (*UtilsStruct) GetOptions() bind.CallOpts { func (*UtilsStruct) GetTxnOpts(transactionData types.TransactionOptions) *bind.TransactOpts { log.Debug("Getting transaction options...") - defaultPath, err := PathInterface.GetDefaultPath() - CheckError("Error in fetching default path: ", err) - keystorePath := filepath.Join(defaultPath, "keystore_files") - privateKey, err := AccountsInterface.GetPrivateKey(transactionData.AccountAddress, transactionData.Password, keystorePath) - if privateKey == nil || err != nil { - CheckError("Error in fetching private key: ", errors.New(transactionData.AccountAddress+" not present in razor-go")) + account := transactionData.Account + if account.AccountManager == nil { + log.Fatal("Account Manager in transaction data is not initialised") } - nonce, err := ClientInterface.GetNonceAtWithRetry(transactionData.Client, common.HexToAddress(transactionData.AccountAddress)) + privateKey, err := account.AccountManager.GetPrivateKey(account.Address, account.Password) + CheckError("Error in fetching private key: ", err) + + nonce, err := ClientInterface.GetNonceAtWithRetry(transactionData.Client, common.HexToAddress(account.Address)) CheckError("Error in fetching nonce: ", err) gasPrice := GasInterface.GetGasPrice(transactionData.Client, transactionData.Config) @@ -103,7 +102,7 @@ func (*GasStruct) GetGasLimit(transactionData types.TransactionOptions, txnOpts } contractAddress := common.HexToAddress(transactionData.ContractAddress) msg := ethereum.CallMsg{ - From: common.HexToAddress(transactionData.AccountAddress), + From: common.HexToAddress(transactionData.Account.Address), To: &contractAddress, GasPrice: txnOpts.GasPrice, Value: txnOpts.Value, diff --git a/utils/options_test.go b/utils/options_test.go index c7ecdfe71..54d7f9071 100644 --- a/utils/options_test.go +++ b/utils/options_test.go @@ -7,6 +7,7 @@ import ( "errors" "github.com/ethereum/go-ethereum/crypto" "math/big" + "razor/accounts" "razor/core/types" "razor/utils/mocks" "reflect" @@ -17,7 +18,6 @@ import ( "github.com/ethereum/go-ethereum/accounts/abi/bind" Types "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethclient" - "github.com/magiconair/properties/assert" "github.com/stretchr/testify/mock" ) @@ -134,16 +134,13 @@ func Test_getGasPrice(t *testing.T) { } func Test_utils_GetTxnOpts(t *testing.T) { - var transactionData types.TransactionOptions var gasPrice *big.Int privateKey, _ := ecdsa.GenerateKey(crypto.S256(), rand.Reader) txnOpts, _ := bind.NewKeyedTransactorWithChainID(privateKey, big.NewInt(1)) type args struct { - path string - pathErr error - privateKey *ecdsa.PrivateKey + address string nonce uint64 nonceErr error txnOpts *bind.TransactOpts @@ -162,36 +159,32 @@ func Test_utils_GetTxnOpts(t *testing.T) { { name: "Test 1: When GetTxnOptions execute successfully", args: args{ - path: "/home/local", - privateKey: privateKey, - nonce: 2, - txnOpts: txnOpts, - gasLimit: 1, + address: "0x57Baf83BAD5bee0F7F44d84669A50C35c57E3576", + nonce: 2, + txnOpts: txnOpts, + gasLimit: 1, }, want: txnOpts, expectedFatal: false, }, { - name: "Test 2: When there is an error in getting path", + name: "Test 2: When there is an error in getting private key as address is not present in keystore", args: args{ - path: "/home/local", - pathErr: errors.New("path error"), - privateKey: privateKey, - nonce: 2, - txnOpts: txnOpts, - gasLimit: 1, + address: "0x77Baf83BAD5bee0F7F44d84669A50C35c57E3576", + nonce: 2, + txnOpts: txnOpts, + gasLimit: 1, }, want: txnOpts, expectedFatal: true, }, { - name: "Test 3: When the privateKey is nil", + name: "Test 3: When the accountManager is nil", args: args{ - path: "/home/local", - privateKey: nil, - nonce: 2, - txnOpts: txnOpts, - gasLimit: 1, + address: "", + nonce: 2, + txnOpts: txnOpts, + gasLimit: 1, }, want: txnOpts, expectedFatal: true, @@ -199,12 +192,11 @@ func Test_utils_GetTxnOpts(t *testing.T) { { name: "Test 4: When there is an error in getting nonce", args: args{ - path: "/home/local", - privateKey: privateKey, - nonce: 2, - nonceErr: errors.New("nonce error"), - txnOpts: txnOpts, - gasLimit: 1, + address: "0x57Baf83BAD5bee0F7F44d84669A50C35c57E3576", + nonce: 2, + nonceErr: errors.New("nonce error"), + txnOpts: txnOpts, + gasLimit: 1, }, want: txnOpts, expectedFatal: true, @@ -212,8 +204,7 @@ func Test_utils_GetTxnOpts(t *testing.T) { { name: "Test 5: When there is an error in getting transactor", args: args{ - path: "/home/local", - privateKey: privateKey, + address: "0x57Baf83BAD5bee0F7F44d84669A50C35c57E3576", nonce: 2, txnOpts: txnOpts, txnOptsErr: errors.New("transactor error"), @@ -225,8 +216,7 @@ func Test_utils_GetTxnOpts(t *testing.T) { { name: "Test 6: When there is an error in getting gasLimit", args: args{ - path: "/home/local", - privateKey: privateKey, + address: "0x57Baf83BAD5bee0F7F44d84669A50C35c57E3576", nonce: 2, txnOpts: txnOpts, gasLimitErr: errors.New("gasLimit error"), @@ -235,10 +225,9 @@ func Test_utils_GetTxnOpts(t *testing.T) { expectedFatal: false, }, { - name: "Test 6: When there is an rpc error in getting gasLimit", + name: "Test 7: When there is an rpc error in getting gasLimit", args: args{ - path: "/home/local", - privateKey: privateKey, + address: "0x57Baf83BAD5bee0F7F44d84669A50C35c57E3576", nonce: 2, txnOpts: txnOpts, gasLimitErr: errors.New("504 gateway error"), @@ -250,10 +239,9 @@ func Test_utils_GetTxnOpts(t *testing.T) { expectedFatal: false, }, { - name: "Test 7: When there is an rpc error in getting gasLimit and than error in getting latest header", + name: "Test 8: When there is an rpc error in getting gasLimit and than error in getting latest header", args: args{ - path: "/home/local", - privateKey: privateKey, + address: "0x57Baf83BAD5bee0F7F44d84669A50C35c57E3576", nonce: 2, txnOpts: txnOpts, gasLimitErr: errors.New("504 gateway error"), @@ -267,32 +255,42 @@ func Test_utils_GetTxnOpts(t *testing.T) { }, } - defer func() { log.ExitFunc = nil }() - var fatal bool - log.ExitFunc = func(int) { fatal = true } + originalExitFunc := log.ExitFunc // Preserve the original ExitFunc + defer func() { log.ExitFunc = originalExitFunc }() // Ensure it's reset after tests for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + fatalOccurred := false + + // Override log.ExitFunc to induce a panic for testing the fatal scenario + log.ExitFunc = func(int) { panic("log.Fatal called") } + + var account types.Account + accountManager := accounts.NewAccountManager("test_accounts") + if tt.args.address != "" { + account = accounts.InitAccountStruct(tt.args.address, "Test@123", accountManager) + } else { + account = types.Account{} + } + transactionData := types.TransactionOptions{ + Account: account, + } utilsMock := new(mocks.Utils) pathMock := new(mocks.PathUtils) bindMock := new(mocks.BindUtils) - accountsMock := new(mocks.AccountsUtils) clientMock := new(mocks.ClientUtils) gasMock := new(mocks.GasUtils) optionsPackageStruct := OptionsPackageStruct{ - UtilsInterface: utilsMock, - PathInterface: pathMock, - BindInterface: bindMock, - AccountsInterface: accountsMock, - ClientInterface: clientMock, - GasInterface: gasMock, + UtilsInterface: utilsMock, + PathInterface: pathMock, + BindInterface: bindMock, + ClientInterface: clientMock, + GasInterface: gasMock, } utils := StartRazor(optionsPackageStruct) - pathMock.On("GetDefaultPath").Return(tt.args.path, tt.args.pathErr) - accountsMock.On("GetPrivateKey", mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("string")).Return(tt.args.privateKey, nil) clientMock.On("GetNonceAtWithRetry", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("common.Address")).Return(tt.args.nonce, tt.args.nonceErr) gasMock.On("GetGasPrice", mock.AnythingOfType("*ethclient.Client"), mock.AnythingOfType("types.Configurations")).Return(gasPrice) bindMock.On("NewKeyedTransactorWithChainID", mock.AnythingOfType("*ecdsa.PrivateKey"), mock.AnythingOfType("*big.Int")).Return(tt.args.txnOpts, tt.args.txnOptsErr) @@ -301,10 +299,29 @@ func Test_utils_GetTxnOpts(t *testing.T) { utilsMock.On("MultiplyFloatAndBigInt", mock.AnythingOfType("*big.Int"), mock.AnythingOfType("float64")).Return(big.NewInt(1)) clientMock.On("GetLatestBlockWithRetry", mock.AnythingOfType("*ethclient.Client")).Return(tt.args.latestHeader, tt.args.latestHeaderErr) - fatal = false + // Defer a function to recover from the panic and check if it matches the expectedFatal condition + defer func() { + if r := recover(); r != nil { + // A panic occurred, check if it was expected + if tt.expectedFatal { + // Panic (fatal) was expected and occurred, so this is correct + fatalOccurred = true + } else { + // Panic occurred but was not expected, fail the test + t.Errorf("Unexpected log.Fatal call") + } + } else { + // No panic occurred, check if it was expected + if tt.expectedFatal { + // Expected a fatal condition but it didn't occur, fail the test + t.Errorf("Expected log.Fatal call did not occur") + } + } + }() + got := utils.GetTxnOpts(transactionData) - if tt.expectedFatal { - assert.Equal(t, tt.expectedFatal, fatal) + if !tt.expectedFatal && fatalOccurred { + t.Fatalf("Test exited due to an unexpected fatal condition") } if got != tt.want { t.Errorf("GetTxnOpts() function, got = %v, want = %v", got, tt.want) diff --git a/utils/struct-utils.go b/utils/struct-utils.go index 4eb404a85..d22ac17d3 100644 --- a/utils/struct-utils.go +++ b/utils/struct-utils.go @@ -9,7 +9,6 @@ import ( "io/fs" "math/big" "os" - "razor/accounts" "razor/client" "razor/core" coretypes "razor/core/types" @@ -29,7 +28,6 @@ import ( ) var RPCTimeout int64 -var HTTPTimeout int64 func StartRazor(optionsPackageStruct OptionsPackageStruct) Utils { UtilsInterface = optionsPackageStruct.UtilsInterface @@ -42,7 +40,6 @@ func StartRazor(optionsPackageStruct OptionsPackageStruct) Utils { ABIInterface = optionsPackageStruct.ABIInterface PathInterface = optionsPackageStruct.PathInterface BindInterface = optionsPackageStruct.BindInterface - AccountsInterface = optionsPackageStruct.AccountsInterface BlockManagerInterface = optionsPackageStruct.BlockManagerInterface StakeManagerInterface = optionsPackageStruct.StakeManagerInterface AssetManagerInterface = optionsPackageStruct.AssetManagerInterface @@ -249,10 +246,6 @@ func (v VoteManagerStruct) GetSaltFromBlockchain(client *ethclient.Client) ([32] return returnedValues[0].Interface().([32]byte), nil } -func (a AccountsStruct) GetPrivateKey(address string, password string, keystorePath string) (*ecdsa.PrivateKey, error) { - return accounts.AccountUtilsInterface.GetPrivateKey(address, password, keystorePath) -} - func (b BlockManagerStruct) GetNumProposedBlocks(client *ethclient.Client, epoch uint32) (uint8, error) { blockManager, opts := UtilsInterface.GetBlockManagerWithOpts(client) returnedValues := InvokeFunctionWithTimeout(blockManager, "GetNumProposedBlocks", &opts, epoch)