Skip to content

Commit

Permalink
add chainId method no caching
Browse files Browse the repository at this point in the history
  • Loading branch information
augustbleeds committed May 2, 2024
1 parent 0b1d9f9 commit 3b761cc
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 16 deletions.
30 changes: 30 additions & 0 deletions relayer/pkg/starknet/chain_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type FinalizedBlock = starknetrpc.Block
type StarknetBatchBuilder interface {
RequestBlockByHash(h *felt.Felt) StarknetBatchBuilder
RequestBlockByNumber(id uint64) StarknetBatchBuilder
RequestChainId() StarknetBatchBuilder
// RequestLatestPendingBlock() (StarknetBatchBuilder)
RequestLatestBlockHashAndNumber() StarknetBatchBuilder
// RequestEventsByFilter(f starknetrpc.EventFilter) (StarknetBatchBuilder)
Expand All @@ -35,6 +36,15 @@ func NewBatchBuilder() StarknetBatchBuilder {
}
}

func (b *batchBuilder) RequestChainId() StarknetBatchBuilder {
b.args = append(b.args, gethrpc.BatchElem{
Method: "starknet_chainId",
Args: nil,
Result: new(string),
})
return b
}

func (b *batchBuilder) RequestBlockByHash(h *felt.Felt) StarknetBatchBuilder {
b.args = append(b.args, gethrpc.BatchElem{
Method: "starknet_getBlockWithTxs",
Expand Down Expand Up @@ -75,6 +85,7 @@ type StarknetChainClient interface {
BlockByHash(ctx context.Context, h *felt.Felt) (FinalizedBlock, error)
// only finalized blocks have numbers
BlockByNumber(ctx context.Context, id uint64) (FinalizedBlock, error)
ChainId(ctx context.Context) (string, error)
// only way to get the latest pending block (only 1 pending block exists at a time)
// LatestPendingBlock(ctx context.Context) (starknetrpc.PendingBlock, error)
// returns block number and block has of latest finalized block
Expand All @@ -87,6 +98,25 @@ type StarknetChainClient interface {

var _ StarknetChainClient = (*Client)(nil)

func (c *Client) ChainId(ctx context.Context) (string, error) {
// we do not use c.Provider.ChainID method because it caches
// the chainId after the first request

results, err := c.Batch(ctx, NewBatchBuilder().RequestChainId())

if err != nil {
return "", fmt.Errorf("error in ChainId: %w", err)
}

chainId, ok := results[0].Result.(*string)

if !ok {
return "", fmt.Errorf("expected type string block but found: %T", chainId)
}

return *chainId, nil
}

func (c *Client) BlockByHash(ctx context.Context, h *felt.Felt) (FinalizedBlock, error) {
if c.defaultTimeout != 0 {
var cancel context.CancelFunc
Expand Down
71 changes: 55 additions & 16 deletions relayer/pkg/starknet/chain_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import (

var (
myChainID = "SN_SEPOLIA"
myTimeout = 10 * time.Second
myTimeout = 100 * time.Second
blockNumber = 48719
blockHash, _ = new(felt.Felt).SetString("0x725407fcc3bd43e50884f50f1e0ef32aa9f814af3da475411934a7dbd4b41a")
blockResponse = []byte(`
Expand Down Expand Up @@ -71,6 +71,8 @@ var (
"0x725407fcc3bd43e50884f50f1e0ef32aa9f814af3da475411934a7dbd4b41a",
48719,
)
// hex-encoded value for "SN_SEPOLIA"
chainId = "0x534e5f5345504f4c4941"
)

func TestChainClient(t *testing.T) {
Expand All @@ -96,6 +98,8 @@ func TestChainClient(t *testing.T) {
out = []byte(`{"result": 1}`)
case "starknet_blockHashAndNumber":
out = []byte(fmt.Sprintf(`{"result": %s}`, blockHashAndNumberResponse))
case "starknet_chainId":
out = []byte(fmt.Sprintf(`{"result": "%s"}`, chainId))
default:
require.False(t, true, "unsupported RPC method %s", call.Method)
}
Expand All @@ -105,8 +109,24 @@ func TestChainClient(t *testing.T) {
errBatchMarshal := json.Unmarshal(req, &batchCall)
assert.NoError(t, errBatchMarshal)

response := fmt.Sprintf(`
// special case where we test chainId call
if len(batchCall) == 1 {
response := fmt.Sprintf(`
[
{ "jsonrpc": "2.0",
"id": %d,
"result": "%s"
}
]`, batchCall[0].Id, chainId)
out = []byte(response)
} else {

response := fmt.Sprintf(`
[
{ "jsonrpc": "2.0",
"id": %d,
"result": "%s"
},
{
"jsonrpc": "2.0",
"id": %d,
Expand All @@ -122,12 +142,15 @@ func TestChainClient(t *testing.T) {
"id": %d,
"result": %s
}
]`, batchCall[0].Id, blockResponse,
batchCall[1].Id, blockResponse,
batchCall[2].Id, blockHashAndNumberResponse,
)
]`, batchCall[0].Id, chainId,
batchCall[1].Id, blockResponse,
batchCall[2].Id, blockResponse,
batchCall[3].Id, blockHashAndNumberResponse,
)

out = []byte(response)
out = []byte(response)

}

}

Expand Down Expand Up @@ -160,39 +183,55 @@ func TestChainClient(t *testing.T) {
assert.Equal(t, uint64(blockNumber), output.BlockNumber)
})

t.Run("get ChainId", func(t *testing.T) {
output, err := client.ChainId(context.TODO())
require.NoError(t, err)
assert.Equal(t, chainId, output)
})

t.Run("get Batch", func(t *testing.T) {
builder := NewBatchBuilder()
builder.
RequestChainId().
RequestBlockByHash(blockHash).
RequestBlockByNumber(uint64(blockNumber)).
RequestLatestBlockHashAndNumber()

results, err := client.Batch(context.TODO(), builder)
require.NoError(t, err)

assert.Equal(t, 3, len(results))
assert.Equal(t, 4, len(results))

t.Run("gets BlockByHash in Batch", func(t *testing.T) {
assert.Equal(t, "starknet_getBlockWithTxs", results[0].Method)
t.Run("gets ChainId in Batch", func(t *testing.T) {
assert.Equal(t, "starknet_chainId", results[0].Method)
assert.Nil(t, results[0].Error)
block, ok := results[0].Result.(*FinalizedBlock)
id, ok := results[0].Result.(*string)
assert.True(t, ok)
assert.Equal(t, blockHash, block.BlockHash)
fmt.Println(id)
assert.Equal(t, chainId, *id)
})

t.Run("gets BlockByNumber in Batch", func(t *testing.T) {
t.Run("gets BlockByHash in Batch", func(t *testing.T) {
assert.Equal(t, "starknet_getBlockWithTxs", results[1].Method)
assert.Nil(t, results[1].Error)
block, ok := results[1].Result.(*FinalizedBlock)
assert.True(t, ok)
assert.Equal(t, blockHash, block.BlockHash)
})

t.Run("gets BlockByNumber in Batch", func(t *testing.T) {
assert.Equal(t, "starknet_getBlockWithTxs", results[2].Method)
assert.Nil(t, results[2].Error)
block, ok := results[2].Result.(*FinalizedBlock)
assert.True(t, ok)
assert.Equal(t, uint64(blockNumber), block.BlockNumber)

})

t.Run("gets LatestBlockHashAndNumber in Batch", func(t *testing.T) {
assert.Equal(t, "starknet_blockHashAndNumber", results[2].Method)
assert.Nil(t, results[2].Error)
info, ok := results[2].Result.(*starknetrpc.BlockHashAndNumberOutput)
assert.Equal(t, "starknet_blockHashAndNumber", results[3].Method)
assert.Nil(t, results[3].Error)
info, ok := results[3].Result.(*starknetrpc.BlockHashAndNumberOutput)
assert.True(t, ok)
assert.Equal(t, blockHash, info.BlockHash)
assert.Equal(t, uint64(blockNumber), info.BlockNumber)
Expand Down

0 comments on commit 3b761cc

Please sign in to comment.