4
4
from collections import defaultdict
5
5
from dataclasses import dataclass
6
6
from hashlib import blake2b
7
- from typing import Optional , Any , Union , Callable , Awaitable , cast
7
+ from typing import Optional , Any , Union , Callable , Awaitable , cast , TYPE_CHECKING
8
8
9
9
from bt_decode import PortableRegistry , decode as decode_by_type_string , MetadataV15
10
10
from async_property import async_property
19
19
BlockNotFound ,
20
20
)
21
21
from substrateinterface .storage import StorageKey
22
- import websockets
22
+ from websockets .asyncio .client import connect
23
+ from websockets .exceptions import ConnectionClosed
24
+
25
+ if TYPE_CHECKING :
26
+ from websockets .asyncio .client import ClientConnection
23
27
24
28
ResultHandler = Callable [[dict , Any ], Awaitable [tuple [dict , bool ]]]
25
29
@@ -433,7 +437,7 @@ def add_item(
433
437
self .block_hashes [block_hash ] = runtime
434
438
435
439
def retrieve (
436
- self , block : Optional [int ], block_hash : Optional [str ]
440
+ self , block : Optional [int ] = None , block_hash : Optional [str ] = None
437
441
) -> Optional ["Runtime" ]:
438
442
if block is not None :
439
443
return self .blocks .get (block )
@@ -624,7 +628,7 @@ def __init__(
624
628
# TODO allow setting max concurrent connections and rpc subscriptions per connection
625
629
# TODO reconnection logic
626
630
self .ws_url = ws_url
627
- self .ws : Optional [websockets . WebSocketClientProtocol ] = None
631
+ self .ws : Optional ["ClientConnection" ] = None
628
632
self .id = 0
629
633
self .max_subscriptions = max_subscriptions
630
634
self .max_connections = max_connections
@@ -646,15 +650,12 @@ async def __aenter__(self):
646
650
self ._exit_task .cancel ()
647
651
if not self ._initialized :
648
652
self ._initialized = True
649
- await self ._connect ()
653
+ self .ws = await asyncio .wait_for (
654
+ connect (self .ws_url , ** self ._options ), timeout = 10
655
+ )
650
656
self ._receiving_task = asyncio .create_task (self ._start_receiving ())
651
657
return self
652
658
653
- async def _connect (self ):
654
- self .ws = await asyncio .wait_for (
655
- websockets .connect (self .ws_url , ** self ._options ), timeout = 10
656
- )
657
-
658
659
async def __aexit__ (self , exc_type , exc_val , exc_tb ):
659
660
async with self ._lock :
660
661
self ._in_use -= 1
@@ -695,9 +696,7 @@ async def shutdown(self):
695
696
696
697
async def _recv (self ) -> None :
697
698
try :
698
- response = json .loads (
699
- await cast (websockets .WebSocketClientProtocol , self .ws ).recv ()
700
- )
699
+ response = json .loads (await self .ws .recv ())
701
700
async with self ._lock :
702
701
self ._open_subscriptions -= 1
703
702
if "id" in response :
@@ -706,7 +705,7 @@ async def _recv(self) -> None:
706
705
self ._received [response ["params" ]["subscription" ]] = response
707
706
else :
708
707
raise KeyError (response )
709
- except websockets . ConnectionClosed :
708
+ except ConnectionClosed :
710
709
raise
711
710
except KeyError as e :
712
711
raise e
@@ -717,7 +716,7 @@ async def _start_receiving(self):
717
716
await self ._recv ()
718
717
except asyncio .CancelledError :
719
718
pass
720
- except websockets . ConnectionClosed :
719
+ except ConnectionClosed :
721
720
# TODO try reconnect, but only if it's needed
722
721
raise
723
722
@@ -734,7 +733,7 @@ async def send(self, payload: dict) -> int:
734
733
try :
735
734
await self .ws .send (json .dumps ({** payload , ** {"id" : original_id }}))
736
735
return original_id
737
- except websockets . ConnectionClosed :
736
+ except ConnectionClosed :
738
737
raise
739
738
740
739
async def retrieve (self , item_id : int ) -> Optional [dict ]:
@@ -775,7 +774,6 @@ def __init__(
775
774
chain_endpoint ,
776
775
options = {
777
776
"max_size" : 2 ** 32 ,
778
- "read_limit" : 2 ** 16 ,
779
777
"write_limit" : 2 ** 16 ,
780
778
},
781
779
)
@@ -1135,7 +1133,7 @@ async def create_storage_key(
1135
1133
-------
1136
1134
StorageKey
1137
1135
"""
1138
- runtime = await self .init_runtime (block_hash = block_hash )
1136
+ await self .init_runtime (block_hash = block_hash )
1139
1137
1140
1138
return StorageKey .create_from_storage_function (
1141
1139
pallet ,
@@ -1555,7 +1553,7 @@ async def _process_response(
1555
1553
self ,
1556
1554
response : dict ,
1557
1555
subscription_id : Union [int , str ],
1558
- value_scale_type : Optional [str ],
1556
+ value_scale_type : Optional [str ] = None ,
1559
1557
storage_item : Optional [ScaleType ] = None ,
1560
1558
runtime : Optional [Runtime ] = None ,
1561
1559
result_handler : Optional [ResultHandler ] = None ,
0 commit comments