diff --git a/CHANGELOG.md b/CHANGELOG.md index 611d914e4..049d3ea72 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,10 @@ ## Next release - feat: fetch eth/strk price and sync strk gas price +- fix: contract 0 state diff fixed +- refactor(rpc): re-worked rpc tower server and added proper websocket support +- fix(network): added the FGW and gateway url to the chain config +- fix(block_hash): block hash mismatch on transaction with an empty signature - feat: declare v0, l1 handler support added - feat: strk gas price cli param added - fix(snos): added special address while closing block for SNOS diff --git a/Cargo.lock b/Cargo.lock index 5ee31c545..131f6b69f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5720,6 +5720,7 @@ dependencies = [ "anyhow", "blockifier", "lazy_static", + "log", "mp-utils", "primitive-types", "rstest 0.18.2", @@ -5729,6 +5730,7 @@ dependencies = [ "starknet-types-core", "starknet_api", "thiserror", + "url", ] [[package]] @@ -8932,6 +8934,7 @@ dependencies = [ "form_urlencoded", "idna", "percent-encoding", + "serde", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 0d1e639be..f06545f50 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -184,7 +184,7 @@ serde_json = { version = "1.0", default-features = false, features = ["std"] } serde_yaml = { version = "0.9.34" } thiserror = "1.0" tokio = { version = "1.34", features = ["signal", "rt"] } -url = "2.4" +url = { version = "2.4", features = ["serde"] } rayon = "1.10" bincode = "1.3" prometheus = "0.13.4" diff --git a/configs/presets/devnet.yaml b/configs/presets/devnet.yaml index 46b051587..2ef5ab76d 100644 --- a/configs/presets/devnet.yaml +++ b/configs/presets/devnet.yaml @@ -1,5 +1,7 @@ chain_name: "Madara" chain_id: "MADARA_DEVNET" +feeder_gateway_url: "http://localhost:8080/feeder_gateway/" +gateway_url: "http://localhost:8080/gateway/" native_fee_token_address: "0x04718f5a0fc34cc1af16a1cdee98ffb20c31f5cd61d6ab07201858f4287c938d" parent_fee_token_address: "0x049d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7" latest_protocol_version: "0.13.2" diff --git a/configs/presets/integration.yaml b/configs/presets/integration.yaml index 3285802c5..981a5ae14 100644 --- a/configs/presets/integration.yaml +++ b/configs/presets/integration.yaml @@ -1,5 +1,7 @@ chain_name: "Starknet Sepolia" chain_id: "SN_INTEGRATION_SEPOLIA" +feeder_gateway_url: "https://integration-sepolia.starknet.io/feeder_gateway/" +gateway_url: "https://integration-sepolia.starknet.io/gateway/" native_fee_token_address: "0x04718f5a0fc34cc1af16a1cdee98ffb20c31f5cd61d6ab07201858f4287c938d" parent_fee_token_address: "0x049d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7" latest_protocol_version: "0.13.2" diff --git a/configs/presets/mainnet.yaml b/configs/presets/mainnet.yaml index 8fbd3dedc..716b6a85a 100644 --- a/configs/presets/mainnet.yaml +++ b/configs/presets/mainnet.yaml @@ -1,5 +1,7 @@ chain_name: "Starknet Mainnet" chain_id: "SN_MAIN" +feeder_gateway_url: "https://alpha-mainnet.starknet.io/feeder_gateway/" +gateway_url: "https://alpha-mainnet.starknet.io/gateway/" native_fee_token_address: "0x04718f5a0fc34cc1af16a1cdee98ffb20c31f5cd61d6ab07201858f4287c938d" parent_fee_token_address: "0x049d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7" latest_protocol_version: "0.13.2" diff --git a/configs/presets/sepolia.yaml b/configs/presets/sepolia.yaml index 50d9bbfcc..f4aded9cf 100644 --- a/configs/presets/sepolia.yaml +++ b/configs/presets/sepolia.yaml @@ -1,5 +1,7 @@ chain_name: "Starknet Sepolia" chain_id: "SN_SEPOLIA" +feeder_gateway_url: "https://alpha-sepolia.starknet.io/feeder_gateway/" +gateway_url: "https://alpha-sepolia.starknet.io/gateway/" native_fee_token_address: "0x04718f5a0fc34cc1af16a1cdee98ffb20c31f5cd61d6ab07201858f4287c938d" parent_fee_token_address: "0x049d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7" latest_protocol_version: "0.13.2" diff --git a/crates/client/gateway/src/server/router.rs b/crates/client/gateway/src/server/router.rs index c0f25090f..eaa79e02d 100644 --- a/crates/client/gateway/src/server/router.rs +++ b/crates/client/gateway/src/server/router.rs @@ -22,9 +22,9 @@ pub(crate) async fn main_router( match (path.as_ref(), feeder_gateway_enable, gateway_enable) { ("health", _, _) => Ok(Response::new(Body::from("OK"))), (path, true, _) if path.starts_with("feeder_gateway/") => feeder_gateway_router(req, path, backend).await, - (path, _, true) if path.starts_with("feeder/") => gateway_router(req, path, add_transaction_provider).await, + (path, _, true) if path.starts_with("gateway/") => gateway_router(req, path, add_transaction_provider).await, (path, false, _) if path.starts_with("feeder_gateway/") => Ok(service_unavailable_response("Feeder Gateway")), - (path, _, false) if path.starts_with("feeder/") => Ok(service_unavailable_response("Feeder")), + (path, _, false) if path.starts_with("gateway/") => Ok(service_unavailable_response("Feeder")), _ => { log::debug!(target: "feeder_gateway", "Main router received invalid request: {path}"); Ok(not_found_response()) @@ -74,7 +74,7 @@ async fn gateway_router( add_transaction_provider: Arc, ) -> Result, Infallible> { match (req.method(), req.uri().path()) { - (&Method::POST, "feeder/add_transaction") => { + (&Method::POST, "gateway/add_transaction") => { Ok(handle_add_transaction(req, add_transaction_provider).await.unwrap_or_else(Into::into)) } _ => { diff --git a/crates/client/mempool/src/block_production.rs b/crates/client/mempool/src/block_production.rs index e69141073..edd54558e 100644 --- a/crates/client/mempool/src/block_production.rs +++ b/crates/client/mempool/src/block_production.rs @@ -2,7 +2,7 @@ use blockifier::blockifier::transaction_executor::{TransactionExecutor, VisitedSegmentsMapping}; use blockifier::bouncer::{Bouncer, BouncerWeights, BuiltinCount}; -use blockifier::state::cached_state::CommitmentStateDiff; +use blockifier::state::cached_state::StateMaps; use blockifier::state::state_api::StateReader; use blockifier::transaction::errors::TransactionExecutionError; use mc_block_import::{BlockImportError, BlockImporter}; @@ -19,8 +19,11 @@ use mp_state_update::{ }; use mp_transactions::TransactionWithHash; use mp_utils::graceful_shutdown; +use starknet_api::core::ContractAddress; use starknet_types_core::felt::Felt; use std::borrow::Cow; +use std::collections::hash_map; +use std::collections::HashMap; use std::collections::VecDeque; use std::mem; use std::sync::Arc; @@ -56,20 +59,56 @@ pub enum Error { Unexpected(Cow<'static, str>), } -fn csd_to_state_diff( +fn state_map_to_state_diff( backend: &MadaraBackend, on_top_of: &Option, - csd: &CommitmentStateDiff, + diff: StateMaps, ) -> Result { - let CommitmentStateDiff { - address_to_class_hash, - address_to_nonce, - storage_updates, - class_hash_to_compiled_class_hash, - } = csd; - - let (mut deployed_contracts, mut replaced_classes) = (Vec::new(), Vec::new()); - for (contract_address, new_class_hash) in address_to_class_hash { + let mut backing_map = HashMap::::default(); + let mut storage_diffs = Vec::::default(); + for ((address, key), value) in diff.storage { + match backing_map.entry(address) { + hash_map::Entry::Vacant(e) => { + e.insert(storage_diffs.len()); + storage_diffs.push(ContractStorageDiffItem { + address: address.to_felt(), + storage_entries: vec![StorageEntry { key: key.to_felt(), value }], + }); + } + hash_map::Entry::Occupied(e) => { + storage_diffs[*e.get()].storage_entries.push(StorageEntry { key: key.to_felt(), value }); + } + } + } + + let mut deprecated_declared_classes = Vec::default(); + for (class_hash, _) in diff.declared_contracts { + if !diff.compiled_class_hashes.contains_key(&class_hash) { + deprecated_declared_classes.push(class_hash.to_felt()); + } + } + + let declared_classes = diff + .compiled_class_hashes + .iter() + .map(|(class_hash, compiled_class_hash)| DeclaredClassItem { + class_hash: class_hash.to_felt(), + compiled_class_hash: compiled_class_hash.to_felt(), + }) + .collect(); + + let nonces = diff + .nonces + .into_iter() + .map(|(contract_address, nonce)| NonceUpdate { + contract_address: contract_address.to_felt(), + nonce: nonce.to_felt(), + }) + .collect(); + + let mut deployed_contracts = Vec::new(); + let mut replaced_classes = Vec::new(); + for (contract_address, new_class_hash) in diff.compiled_class_hashes { let replaced = if let Some(on_top_of) = on_top_of { backend.get_contract_class_hash_at(on_top_of, &contract_address.to_felt())?.is_some() } else { @@ -90,31 +129,10 @@ fn csd_to_state_diff( } Ok(StateDiff { - storage_diffs: storage_updates - .into_iter() - .map(|(address, storage_entries)| ContractStorageDiffItem { - address: address.to_felt(), - storage_entries: storage_entries - .into_iter() - .map(|(key, value)| StorageEntry { key: key.to_felt(), value: *value }) - .collect(), - }) - .collect(), - deprecated_declared_classes: vec![], - declared_classes: class_hash_to_compiled_class_hash - .iter() - .map(|(class_hash, compiled_class_hash)| DeclaredClassItem { - class_hash: class_hash.to_felt(), - compiled_class_hash: compiled_class_hash.to_felt(), - }) - .collect(), - nonces: address_to_nonce - .into_iter() - .map(|(contract_address, nonce)| NonceUpdate { - contract_address: contract_address.to_felt(), - nonce: nonce.to_felt(), - }) - .collect(), + storage_diffs, + deprecated_declared_classes, + declared_classes, + nonces, deployed_contracts, replaced_classes, }) @@ -150,13 +168,13 @@ fn finalize_execution_state( backend: &MadaraBackend, on_top_of: &Option, ) -> Result<(StateDiff, VisitedSegmentsMapping, BouncerWeights), Error> { - let csd = tx_executor + let state_map = tx_executor .block_state .as_mut() .expect(BLOCK_STATE_ACCESS_ERR) .to_state_diff() .map_err(TransactionExecutionError::StateError)?; - let state_update = csd_to_state_diff(backend, on_top_of, &csd.into())?; + let state_update = state_map_to_state_diff(backend, on_top_of, state_map)?; let visited_segments = get_visited_segments(tx_executor)?; diff --git a/crates/client/rpc/src/RPC.md b/crates/client/rpc/src/RPC.md new file mode 100644 index 000000000..f10d9504f --- /dev/null +++ b/crates/client/rpc/src/RPC.md @@ -0,0 +1,94 @@ +# RPC + +_This section consists of a brief overview of RPC handling architecture inside +of Madara, as its structure can be quite confusing at first._ + +## Properties + +Madara RPC has the folliwing properties: + +**Each RPC category is independent** and decoupled from the rest, so `trace` +methods exist in isolation from `read` methods for example. + +**RPC methods are versioned**. It is therefore possible for a user to call +_different versions_ of the same RPC method. This is mostly present for ease of +development of new RPC versions, but also serves to assure a level of backwards +compatibility. To select a specific version of an rpc method, you will need to +append `/rcp/v{version}` to the rpc url you are connecting to. + +**RPC versions are grouped under the `Starknet` struct**. This serves as a +common point of implementation for all RPC methods across all versions, and is +also the point of interaction between RPC methods and the node backend. + +> [!NOTE] +> All of this is regrouped as an RPC _service_. + +## Implementation details + +There are **two** main parts to the implementation of RPC methods in Madara. + +### Jsonrpsee implementation + +> [!NOTE] > `jsonrpsee` is a library developed by Parity which is used to implement JSON +> RPC APIs through simple macro logic. + +Each RPC version is defined under the `version` folder using the +`versioned_starknet_rpc` macro. This just serves to rename the trait it is +defined on and all jsonrpsee `#[method]` definitions to include the version +name. The latter is especially important as it avoids name clashes when merging +multiple `RpcModule`s from different versions together. + +#### Renaming + +```rust +#[versioned_starknet_rpc("V0_7_1)")] +trait yourTrait { + #[method(name = "foo")] + async fn foo(); +} +``` + +Will become + +```rust +#[jsonrpsee::proc_macros::rpc(server, namespace = "starknet")] +trait yourTraitV0_7_1 { + #[method(name = "V0_7_1_foo")] + async fn foo(); +} +``` + +### Implementation as a service + +> [!IMPORTANT] +> This is where the RPC server is set up and where RPC calls are actually +> parsed, validated, routed and handled. + +`RpcService` is responsible for starting the rpc service, and hence the rpc +server. This is done with tower in the following steps: + +- RPC apis are built and combined into a single `RpcModule` using + `versioned_rpc_api`, and all other configurations are loaded. + +- Request filtering middleware is set up. This includes host origin filtering + and CORS filtering. + +> [!NOTE] +> Rpc middleware will apply to both websocket and http rpc requests, which is +> why we do not apply versioning in the http middleware. + +- Request constraints are set, such as the maximum number of connections and + request / response size constraints. + +- Additional service layers are added on each rpc call inside `service_fn`. + These are composed into versioning, rate limiting (which is optional) and + metrics layers. Importantly, version filtering with `RpcMiddlewareServiceVersion` + will transforms rpc methods request with header `/rpc/v{version}` and a json rpc + body with a `{method}` field into the correct `starknet_{version}_{method}` rpc + method call, as this is how we version them internally with jsonrpsee. + +> [!NOTE] +> The `starknet` prefix comes from the secondary macro expansion of +> `#[rpc(server, namespace = "starknet)]` + +- Finally, the RPC service is added to tower as `RpcServiceBuilder`. Note that diff --git a/crates/client/rpc/src/lib.rs b/crates/client/rpc/src/lib.rs index f63ff9e8e..fbe64f795 100644 --- a/crates/client/rpc/src/lib.rs +++ b/crates/client/rpc/src/lib.rs @@ -4,7 +4,6 @@ mod constants; mod errors; -mod macros; pub mod providers; #[cfg(test)] pub mod test_utils; @@ -99,14 +98,25 @@ pub fn versioned_rpc_api( write: bool, trace: bool, internal: bool, + ws: bool, ) -> anyhow::Result> { let mut rpc_api = RpcModule::new(()); - merge_rpc_versions!( - rpc_api, starknet, read, write, trace, internal, - v0_7_1, // We can add new versions by adding the version module below - // , v0_8_0 (for example) - ); + if read { + rpc_api.merge(versions::v0_7_1::StarknetReadRpcApiV0_7_1Server::into_rpc(starknet.clone()))?; + } + if write { + rpc_api.merge(versions::v0_7_1::StarknetWriteRpcApiV0_7_1Server::into_rpc(starknet.clone()))?; + } + if trace { + rpc_api.merge(versions::v0_7_1::StarknetTraceRpcApiV0_7_1Server::into_rpc(starknet.clone()))?; + } + if internal { + rpc_api.merge(versions::v0_7_1::MadaraWriteRpcApiV0_7_1Server::into_rpc(starknet.clone()))?; + } + if ws { + // V0.8.0 ... + } Ok(rpc_api) } diff --git a/crates/client/rpc/src/macros.rs b/crates/client/rpc/src/macros.rs deleted file mode 100644 index cb46b06dc..000000000 --- a/crates/client/rpc/src/macros.rs +++ /dev/null @@ -1,21 +0,0 @@ -#[macro_export] -macro_rules! merge_rpc_versions { - ($rpc_api:expr, $starknet:expr, $read:expr, $write:expr, $trace:expr, $internal:expr, $($version:ident),+ $(,)?) => { - $( - paste::paste! { - if $read { - $rpc_api.merge(versions::[<$version>]::[]::into_rpc($starknet.clone()))?; - } - if $write { - $rpc_api.merge(versions::[<$version>]::[]::into_rpc($starknet.clone()))?; - } - if $trace { - $rpc_api.merge(versions::[<$version>]::[]::into_rpc($starknet.clone()))?; - } - if $internal { - $rpc_api.merge(versions::[<$version>]::[]::into_rpc($starknet.clone()))?; - } - } - )+ - }; -} diff --git a/crates/client/rpc/src/versions/v0_7_1/api.rs b/crates/client/rpc/src/versions/v0_7_1/api.rs index 40dcb202c..2c5c2aa45 100644 --- a/crates/client/rpc/src/versions/v0_7_1/api.rs +++ b/crates/client/rpc/src/versions/v0_7_1/api.rs @@ -1,5 +1,4 @@ use jsonrpsee::core::RpcResult; -use jsonrpsee::proc_macros::rpc; use starknet_core::types::{ BlockHashAndNumber, BlockId, BroadcastedDeclareTransaction, BroadcastedDeployAccountTransaction, BroadcastedInvokeTransaction, BroadcastedTransaction, ContractClass, DeclareTransactionResult, diff --git a/crates/client/sync/src/utils.rs b/crates/client/sync/src/utils.rs index bc1399635..04330745a 100644 --- a/crates/client/sync/src/utils.rs +++ b/crates/client/sync/src/utils.rs @@ -2,10 +2,12 @@ use starknet_types_core::felt::Felt; pub fn trim_hash(hash: &Felt) -> String { let hash_str = format!("{:#x}", hash); - let hash_len = hash_str.len(); - let prefix = &hash_str[..6 + 2]; - let suffix = &hash_str[hash_len - 6..]; - - format!("{}...{}", prefix, suffix) + if hash_str.len() <= 12 { + hash_str.to_string() + } else { + let prefix = &hash_str[..6]; + let suffix = &hash_str[hash_str.len() - 6..]; + format!("{}...{}", prefix, suffix) + } } diff --git a/crates/node/src/cli/chain_config_overrides.rs b/crates/node/src/cli/chain_config_overrides.rs index 01f6253da..adb765836 100644 --- a/crates/node/src/cli/chain_config_overrides.rs +++ b/crates/node/src/cli/chain_config_overrides.rs @@ -101,6 +101,8 @@ impl ChainConfigOverrideParams { Ok(ChainConfig { chain_name: chain_config_overrides.chain_name, chain_id: chain_config_overrides.chain_id, + feeder_gateway_url: chain_config.feeder_gateway_url, + gateway_url: chain_config.gateway_url, native_fee_token_address: chain_config_overrides.native_fee_token_address, parent_fee_token_address: chain_config_overrides.parent_fee_token_address, latest_protocol_version: chain_config_overrides.latest_protocol_version, diff --git a/crates/node/src/cli/mod.rs b/crates/node/src/cli/mod.rs index e162e03f8..db0aad629 100644 --- a/crates/node/src/cli/mod.rs +++ b/crates/node/src/cli/mod.rs @@ -24,7 +24,6 @@ use clap::ArgGroup; use mp_chain_config::ChainConfig; use std::path::PathBuf; use std::sync::Arc; -use url::Url; /// Madara: High performance Starknet sequencer/full-node. #[derive(Clone, Debug, clap::Parser)] @@ -202,23 +201,6 @@ pub enum NetworkType { } impl NetworkType { - pub fn uri(&self) -> &'static str { - match self { - NetworkType::Main => "https://alpha-mainnet.starknet.io", - NetworkType::Test => "https://alpha-sepolia.starknet.io", - NetworkType::Integration => "https://integration-sepolia.starknet.io", - NetworkType::Devnet => unreachable!("Gateway url isn't needed for a devnet sequencer"), - } - } - - pub fn gateway(&self) -> Url { - Url::parse(&format!("{}/gateway/", self.uri())).expect("Invalid uri") - } - - pub fn feeder_gateway(&self) -> Url { - Url::parse(&format!("{}/feeder_gateway/", self.uri())).expect("Invalid uri") - } - pub fn chain_id(&self) -> ChainId { match self { NetworkType::Main => ChainId::Mainnet, diff --git a/crates/node/src/cli/rpc.rs b/crates/node/src/cli/rpc.rs index 48d18d8a5..63ded0ea5 100644 --- a/crates/node/src/cli/rpc.rs +++ b/crates/node/src/cli/rpc.rs @@ -4,7 +4,6 @@ use std::num::NonZeroU32; use std::str::FromStr; use clap::ValueEnum; -use ip_network::IpNetwork; use jsonrpsee::server::BatchRequestConfig; /// Available RPC methods. @@ -99,21 +98,6 @@ pub struct RpcParams { #[arg(env = "MADARA_RPC_RATE_LIMIT", long)] pub rpc_rate_limit: Option, - /// Disable RPC rate limiting for certain ip addresses or ranges. - /// - /// Each IP address must be in the following notation: `1.2.3.4/24`. - #[arg(env = "MADARA_RPC_RATE_LIMIT_WHITELISTED_IPS", long, num_args = 1..)] - pub rpc_rate_limit_whitelisted_ips: Vec, - - /// Trust proxy headers for disable rate limiting. - /// - /// When using a reverse proxy setup, the real requester IP is usually added to the headers as `X-Real-IP` or `X-Forwarded-For`. - /// By default, the RPC server will not trust these headers. - /// - /// This is currently only useful for rate-limiting reasons. - #[arg(env = "MADARA_RPC_RATE_LIMIT_TRUST_PROXY_HEADERS", long)] - pub rpc_rate_limit_trust_proxy_headers: bool, - /// Set the maximum RPC request payload size for both HTTP and WebSockets in megabytes. #[arg(env = "MADARA_RPC_MAX_REQUEST_SIZE", long, default_value_t = RPC_DEFAULT_MAX_REQUEST_SIZE_MB)] pub rpc_max_request_size: u32, @@ -147,14 +131,25 @@ pub struct RpcParams { #[arg(env = "MADARA_RPC_MAX_BATCH_REQUEST_LEN", long, conflicts_with_all = &["rpc_disable_batch_requests"], value_name = "LEN")] pub rpc_max_batch_request_len: Option, - /// Specify browser *origins* allowed to access the HTTP & WebSocket RPC servers. + /// Specify browser *origins* allowed to access the HTTP & WebSocket RPC + /// servers. /// /// For most purposes, an origin can be thought of as just `protocol://domain`. - /// By default, only browser requests from localhost will work. + /// Default behavior depends on `rpc_external`: + /// + /// - If rpc_external is set, CORS will default to allow all incoming + /// addresses. + /// - If rpc_external is not set, CORS will default to allow only + /// connections from `localhost`. /// - /// This argument is a comma separated list of origins, or the special `all` value. + /// > If the rpcs are permissive, the same will be true for core, and + /// > vise-versa. /// - /// Learn more about CORS and web security at . + /// This argument is a comma separated list of origins, or the special `all` + /// value. + /// + /// Learn more about CORS and web security at + /// . #[arg(env = "MADARA_RPC_CORS", long, value_name = "ORIGINS")] pub rpc_cors: Option, } @@ -162,12 +157,16 @@ pub struct RpcParams { impl RpcParams { pub fn cors(&self) -> Option> { let cors = self.rpc_cors.clone().unwrap_or_else(|| { - Cors::List(vec![ - "http://localhost:*".into(), - "http://127.0.0.1:*".into(), - "https://localhost:*".into(), - "https://127.0.0.1:*".into(), - ]) + if self.rpc_external { + Cors::All + } else { + Cors::List(vec![ + "http://localhost:*".into(), + "http://127.0.0.1:*".into(), + "https://localhost:*".into(), + "https://127.0.0.1:*".into(), + ]) + } }); match cors { diff --git a/crates/node/src/cli/sync.rs b/crates/node/src/cli/sync.rs index 6d72efd56..0c4c79394 100644 --- a/crates/node/src/cli/sync.rs +++ b/crates/node/src/cli/sync.rs @@ -1,13 +1,12 @@ -use std::time::Duration; +use std::{sync::Arc, time::Duration}; +use mp_chain_config::ChainConfig; use starknet_api::core::ChainId; use mc_sync::fetch::fetchers::FetchConfig; use mp_utils::parsers::{parse_duration, parse_url}; use url::Url; -use crate::cli::NetworkType; - #[derive(Clone, Debug, clap::Args)] pub struct SyncParams { /// Disable the sync service. The sync service is responsible for listening for new blocks on starknet and ethereum. @@ -69,13 +68,13 @@ pub struct SyncParams { } impl SyncParams { - pub fn block_fetch_config(&self, chain_id: ChainId, network: NetworkType) -> FetchConfig { + pub fn block_fetch_config(&self, chain_id: ChainId, chain_config: Arc) -> FetchConfig { let (gateway, feeder_gateway) = match &self.gateway_url { Some(url) => ( - url.join("/gateway/").expect("Error parsing url (this should not panic)"), - url.join("/feeder_gateway/").expect("Error parsing url (this should not panic)"), + url.join("/gateway/").expect("Error parsing url"), + url.join("/feeder_gateway/").expect("Error parsing url"), ), - None => (network.gateway(), network.feeder_gateway()), + None => (chain_config.gateway_url.clone(), chain_config.feeder_gateway_url.clone()), }; let polling = if self.no_sync_polling { None } else { Some(self.sync_polling_interval) }; diff --git a/crates/node/src/main.rs b/crates/node/src/main.rs index b5965f101..3664014af 100644 --- a/crates/node/src/main.rs +++ b/crates/node/src/main.rs @@ -164,9 +164,6 @@ async fn main() -> anyhow::Result<()> { let sync_service = SyncService::new( &run_cmd.sync_params, Arc::clone(&chain_config), - run_cmd.network.context( - "You should provide a `--network` argument to ensure you're syncing from the right FGW", - )?, &db_service, importer, telemetry_service.new_handle(), @@ -175,24 +172,13 @@ async fn main() -> anyhow::Result<()> { .context("Initializing sync service")?; ( - ServiceGroup::default().with(sync_service), - // TODO(rate-limit): we may get rate limited with this unconfigured provider? - Arc::new(ForwardToProvider::new(SequencerGatewayProvider::new( - run_cmd - .network - .context( - "You should provide a `--network` argument to ensure you're syncing from the right gateway", - )? - .gateway(), - run_cmd - .network - .context( - "You should provide a `--network` argument to ensure you're syncing from the right FGW", - )? - .feeder_gateway(), - chain_config.chain_id.to_felt(), - ))), - ) + ServiceGroup::default().with(sync_service), + Arc::new(ForwardToProvider::new(SequencerGatewayProvider::new( + chain_config.gateway_url.clone(), + chain_config.feeder_gateway_url.clone(), + chain_config.chain_id.to_felt(), + ))), + ) } }; diff --git a/crates/node/src/service/rpc.rs b/crates/node/src/service/rpc.rs index 7254a437f..3f464dde9 100644 --- a/crates/node/src/service/rpc.rs +++ b/crates/node/src/service/rpc.rs @@ -46,8 +46,8 @@ impl RpcService { (true, false) } }; - let (read, write, trace, internal) = (rpcs, rpcs, rpcs, node_operator); - let starknet = Starknet::new(Arc::clone(db.backend()), chain_config.clone(), add_txs_method_provider.clone()); + let (read, write, trace, internal, ws) = (rpcs, rpcs, rpcs, node_operator, rpcs); + let starknet = Starknet::new(Arc::clone(db.backend()), chain_config.clone(), add_txs_method_provider); let metrics = RpcMetrics::register(metrics_handle)?; Ok(Self { @@ -59,12 +59,10 @@ impl RpcService { max_payload_out_mb: config.rpc_max_response_size, max_subs_per_conn: config.rpc_max_subscriptions_per_connection, message_buffer_capacity: config.rpc_message_buffer_capacity_per_connection, - rpc_api: versioned_rpc_api(&starknet, read, write, trace, internal)?, + rpc_api: versioned_rpc_api(&starknet, read, write, trace, internal, ws)?, metrics, cors: config.cors(), rate_limit: config.rpc_rate_limit, - rate_limit_whitelisted_ips: config.rpc_rate_limit_whitelisted_ips.clone(), - rate_limit_trust_proxy_headers: config.rpc_rate_limit_trust_proxy_headers, }), server_handle: None, }) diff --git a/crates/node/src/service/rpc/middleware.rs b/crates/node/src/service/rpc/middleware.rs index 75a6a4d7e..f92f9e2f6 100644 --- a/crates/node/src/service/rpc/middleware.rs +++ b/crates/node/src/service/rpc/middleware.rs @@ -1,10 +1,7 @@ //! JSON-RPC specific middleware. -use std::future::Future; use std::num::NonZeroU32; -use std::pin::Pin; use std::sync::Arc; -use std::task::{Context, Poll}; use std::time::{Duration, Instant}; use futures::future::{BoxFuture, FutureExt}; @@ -12,128 +9,158 @@ use governor::clock::{Clock, DefaultClock, QuantaClock}; use governor::middleware::NoOpMiddleware; use governor::state::{InMemoryState, NotKeyed}; use governor::{Jitter, Quota, RateLimiter}; -use hyper::{Body, Response}; use jsonrpsee::server::middleware::rpc::RpcServiceT; -use jsonrpsee::types::{ErrorObject, Request}; -use jsonrpsee::MethodResponse; -use serde_json::{json, Value}; -use tower::{Layer, Service}; -use mp_chain_config::{RpcVersion, RpcVersionError}; +use mp_chain_config::RpcVersion; -pub use super::metrics::{Metrics, RpcMetrics}; +pub use super::metrics::Metrics; /// Rate limit middleware #[derive(Debug, Clone)] pub struct RateLimit { - pub(crate) inner: Arc>, + pub(crate) limiter: Arc>, pub(crate) clock: QuantaClock, } impl RateLimit { pub fn new(max_burst: NonZeroU32) -> Self { let clock = QuantaClock::default(); - Self { inner: Arc::new(RateLimiter::direct_with_clock(Quota::per_minute(max_burst), &clock)), clock } + Self { limiter: Arc::new(RateLimiter::direct_with_clock(Quota::per_minute(max_burst), &clock)), clock } } } const MAX_JITTER: Duration = Duration::from_millis(50); const MAX_RETRIES: usize = 10; -#[derive(Debug, Clone, Default)] -pub struct MiddlewareLayer { - rate_limit: Option, - metrics: Option, +#[derive(Debug, Clone)] +pub struct RpcMiddlewareLayerRateLimit { + rate_limit: RateLimit, } -impl MiddlewareLayer { - pub fn new() -> Self { - Self::default() +impl RpcMiddlewareLayerRateLimit { + pub fn new(n: NonZeroU32) -> Self { + Self { rate_limit: RateLimit::new(n) } } +} + +impl tower::Layer for RpcMiddlewareLayerRateLimit { + type Service = RpcMiddlewareServiceRateLimit; + + fn layer(&self, inner: S) -> Self::Service { + RpcMiddlewareServiceRateLimit { inner, rate_limit: self.rate_limit.clone() } + } +} + +#[derive(Debug, Clone)] +pub struct RpcMiddlewareServiceRateLimit { + inner: S, + rate_limit: RateLimit, +} + +impl<'a, S> RpcServiceT<'a> for RpcMiddlewareServiceRateLimit +where + S: Send + Sync + Clone + RpcServiceT<'a> + 'static, +{ + type Future = BoxFuture<'a, jsonrpsee::MethodResponse>; + + fn call(&self, mut req: jsonrpsee::types::Request<'a>) -> Self::Future { + let inner = self.inner.clone(); + let rate_limit = self.rate_limit.clone(); + + async move { + let mut attempts = 0; + let jitter = Jitter::up_to(MAX_JITTER); + let mut rate_limited = false; + + loop { + if attempts >= MAX_RETRIES { + return jsonrpsee::MethodResponse::error( + req.id, + jsonrpsee::types::ErrorObject::owned(-32099, "RPC rate limit exceeded", None::<()>), + ); + } + + if let Err(rejected) = rate_limit.limiter.check() { + tokio::time::sleep(jitter + rejected.wait_time_from(rate_limit.clock.now())).await; + rate_limited = true; + } else { + break; + } + + attempts += 1; + } - /// Enable new rate limit middleware enforced per minute. - pub fn with_rate_limit_per_minute(self, n: NonZeroU32) -> Self { - Self { rate_limit: Some(RateLimit::new(n)), metrics: self.metrics } + // This should be ok as a way to flag rate limited requests as the + // JSON RPC spec discourages the use of NULL as an id in a _request_ + // since it is used for _responses_ with an unknown id. + if rate_limited { + req.id = jsonrpsee::types::Id::Null; + } + + inner.call(req).await + } + .boxed() } +} +#[derive(Debug, Clone)] +pub struct RpcMiddlewareLayerMetrics { + metrics: Metrics, +} + +impl RpcMiddlewareLayerMetrics { /// Enable metrics middleware. - pub fn with_metrics(self, metrics: Metrics) -> Self { - Self { rate_limit: self.rate_limit, metrics: Some(metrics) } + pub fn new(metrics: Metrics) -> Self { + Self { metrics } } /// Register a new websocket connection. pub fn ws_connect(&self) { - if let Some(m) = self.metrics.as_ref() { - m.ws_connect() - } + self.metrics.ws_connect() } /// Register that a websocket connection was closed. pub fn ws_disconnect(&self, now: Instant) { - if let Some(m) = self.metrics.as_ref() { - m.ws_disconnect(now) - } + self.metrics.ws_disconnect(now) } } -impl tower::Layer for MiddlewareLayer { - type Service = Middleware; +impl tower::Layer for RpcMiddlewareLayerMetrics { + type Service = RpcMiddlewareServiceMetrics; - fn layer(&self, service: S) -> Self::Service { - Middleware { service, rate_limit: self.rate_limit.clone(), metrics: self.metrics.clone() } + fn layer(&self, inner: S) -> Self::Service { + RpcMiddlewareServiceMetrics { inner, metrics: self.metrics.clone() } } } -pub struct Middleware { - service: S, - rate_limit: Option, - metrics: Option, +#[derive(Debug, Clone)] +pub struct RpcMiddlewareServiceMetrics { + inner: S, + metrics: Metrics, } -impl<'a, S> RpcServiceT<'a> for Middleware +impl<'a, S> RpcServiceT<'a> for RpcMiddlewareServiceMetrics where - S: Send + Sync + RpcServiceT<'a> + Clone + 'static, + S: Send + Sync + Clone + RpcServiceT<'a> + 'static, { - type Future = BoxFuture<'a, MethodResponse>; + type Future = BoxFuture<'a, jsonrpsee::MethodResponse>; - fn call(&self, req: Request<'a>) -> Self::Future { - let now = Instant::now(); - - if let Some(m) = self.metrics.as_ref() { - m.on_call(&req) - } - - let service = self.service.clone(); - let rate_limit = self.rate_limit.clone(); + fn call(&self, mut req: jsonrpsee::types::Request<'a>) -> Self::Future { + let inner = self.inner.clone(); let metrics = self.metrics.clone(); async move { - let mut is_rate_limited = false; - - if let Some(limit) = rate_limit.as_ref() { - let mut attempts = 0; - let jitter = Jitter::up_to(MAX_JITTER); - - loop { - if attempts >= MAX_RETRIES { - return MethodResponse::error( - req.id, - ErrorObject::owned(-32999, "RPC rate limit exceeded", None::<()>), - ); - } - - if let Err(rejected) = limit.inner.check() { - tokio::time::sleep(jitter + rejected.wait_time_from(limit.clock.now())).await; - } else { - break; - } - - is_rate_limited = true; - attempts += 1; - } - } + let is_rate_limited = if matches!(req.id, jsonrpsee::types::params::Id::Null) { + req.id = jsonrpsee::types::params::Id::Number(1); + true + } else { + false + }; + + let now = std::time::Instant::now(); - let rp = service.call(req.clone()).await; + metrics.on_call(&req); + let rp = inner.call(req.clone()).await; let method = req.method_name(); let status = rp.as_error_code().unwrap_or(200); @@ -149,9 +176,7 @@ where "{method} {status} {res_len} - {response_time:?}", ); - if let Some(m) = metrics.as_ref() { - m.on_response(&req, &rp, is_rate_limited, now) - } + metrics.on_response(&req, &rp, is_rate_limited, now); rp } @@ -159,139 +184,60 @@ where } } -#[derive(Clone)] -pub struct VersionMiddleware { +#[derive(Debug, Clone)] +pub struct RpcMiddlewareServiceVersion { inner: S, + path: String, } -#[derive(thiserror::Error, Debug)] -enum VersionMiddlewareError { - #[error("Failed to read request body: {0}")] - BodyReadError(#[from] hyper::Error), - #[error("Failed to parse JSON: {0}")] - JsonParseError(#[from] serde_json::Error), - #[error("Invalid URL format")] - InvalidUrlFormat, - #[error("Invalid version specified")] - InvalidVersion, - #[error("Unsupported version specified")] - UnsupportedVersion, - #[error("Invalid method format. Namespace required: {0}")] - InvalidMethodFormat(String), - #[error("Missing method in RPC request")] - MissingMethod, -} - -impl From for VersionMiddlewareError { - fn from(e: RpcVersionError) -> Self { - match e { - RpcVersionError::InvalidNumber(_) => Self::InvalidVersion, - RpcVersionError::InvalidPathSupplied => Self::InvalidUrlFormat, - RpcVersionError::InvalidVersion => Self::InvalidVersion, - RpcVersionError::TooManyComponents(_) => Self::InvalidVersion, - RpcVersionError::UnsupportedVersion => Self::UnsupportedVersion, - } - } -} - -impl VersionMiddleware { - pub fn new(inner: S) -> Self { - Self { inner } - } -} - -#[derive(Clone)] -pub struct VersionMiddlewareLayer; - -impl Layer for VersionMiddlewareLayer { - type Service = VersionMiddleware; - - fn layer(&self, inner: S) -> Self::Service { - VersionMiddleware::new(inner) +impl RpcMiddlewareServiceVersion { + pub fn new(inner: S, path: String) -> Self { + Self { inner, path } } } -impl Service> for VersionMiddleware +impl<'a, S> RpcServiceT<'a> for RpcMiddlewareServiceVersion where - S: Service, Response = Response> + Clone + Send + 'static, - S::Future: Send + 'static, + S: Send + Sync + Clone + RpcServiceT<'a> + 'static, { - type Response = S::Response; - type Error = S::Error; - type Future = Pin> + Send + 'static>>; + type Future = BoxFuture<'a, jsonrpsee::MethodResponse>; - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_ready(cx) - } + fn call(&self, mut req: jsonrpsee::types::Request<'a>) -> Self::Future { + let inner = self.inner.clone(); + let path = self.path.clone(); - fn call(&mut self, mut req: hyper::Request) -> Self::Future { - let mut inner = self.inner.clone(); - - Box::pin(async move { - match add_rpc_version_to_method(&mut req).await { - Ok(()) => inner.call(req).await, - Err(e) => { - let error = match e { - VersionMiddlewareError::InvalidUrlFormat => { - ErrorObject::owned(-32600, "Invalid URL format. Use /rpc/v{version}", None::<()>) - } - VersionMiddlewareError::InvalidVersion => { - ErrorObject::owned(-32600, "Invalid RPC version specified", None::<()>) - } - VersionMiddlewareError::UnsupportedVersion => { - ErrorObject::owned(-32601, "Unsupported RPC version specified", None::<()>) - } - _ => ErrorObject::owned(-32603, "Internal error", None::<()>), - }; - - let body = json!({ - "jsonrpc": "2.0", - "error": error, - "id": null - }) - .to_string(); - - Ok(Response::builder() - .header("Content-Type", "application/json") - .body(Body::from(body)) - .unwrap_or_else(|_| Response::new(Body::from("Internal server error")))) - } + async move { + if req.method == "rpc_methods" { + return inner.call(req).await; } - }) - } -} -async fn add_rpc_version_to_method(req: &mut hyper::Request) -> Result<(), VersionMiddlewareError> { - let path = req.uri().path().to_string(); - let version = RpcVersion::from_request_path(&path)?; - - let whole_body = hyper::body::to_bytes(req.body_mut()).await?; - let json: Value = serde_json::from_slice(&whole_body)?; - - // in case of batched requests, the request is an array of JSON-RPC requests - let mut batched_request = false; - let mut items = if let Value::Array(items) = json { - batched_request = true; - items - } else { - vec![json] - }; - - for item in items.iter_mut() { - if let Some(method) = item.get_mut("method").as_deref().and_then(Value::as_str) { - let new_method = if let Some((prefix, suffix)) = method.split_once('_') { - format!("{}_{}_{}", prefix, version.name(), suffix) - } else { - return Err(VersionMiddlewareError::InvalidMethodFormat(method.to_string())); + let Ok(version) = RpcVersion::from_request_path(&path) else { + return jsonrpsee::MethodResponse::error( + req.id, + jsonrpsee::types::ErrorObject::owned( + jsonrpsee::types::error::PARSE_ERROR_CODE, + jsonrpsee::types::error::PARSE_ERROR_MSG, + None::<()>, + ), + ); }; - item["method"] = Value::String(new_method); - } else { - return Err(VersionMiddlewareError::MissingMethod); - } - } - let response = if batched_request { serde_json::to_vec(&items)? } else { serde_json::to_vec(&items[0])? }; - *req.body_mut() = Body::from(response); + let Some(method_without_namespace) = req.method.strip_prefix("starknet_") else { + return jsonrpsee::MethodResponse::error( + req.id(), + jsonrpsee::types::ErrorObject::owned( + jsonrpsee::types::error::METHOD_NOT_FOUND_CODE, + jsonrpsee::types::error::METHOD_NOT_FOUND_MSG, + Some(req.method_name()), + ), + ); + }; - Ok(()) + let method_new = format!("starknet_{}_{}", version.name(), method_without_namespace); + req.method = jsonrpsee::core::Cow::from(method_new); + + inner.call(req).await + } + .boxed() + } } diff --git a/crates/node/src/service/rpc/server.rs b/crates/node/src/service/rpc/server.rs index 379244203..4ab03657a 100644 --- a/crates/node/src/service/rpc/server.rs +++ b/crates/node/src/service/rpc/server.rs @@ -2,31 +2,20 @@ #![allow(clippy::borrow_interior_mutable_const)] use std::convert::Infallible; -use std::net::{IpAddr, SocketAddr}; +use std::net::SocketAddr; use std::num::NonZeroU32; -use std::str::FromStr; use std::time::Duration; use anyhow::Context; -use forwarded_header_value::ForwardedHeaderValue; -use hyper::header::{HeaderName, HeaderValue}; -use hyper::server::conn::AddrStream; -use hyper::service::{make_service_fn, service_fn}; -use hyper::{Body, Request, Response, StatusCode}; -use ip_network::IpNetwork; -use jsonrpsee::core::id_providers::RandomStringIdProvider; -use jsonrpsee::server::middleware::http::HostFilterLayer; -use jsonrpsee::server::middleware::rpc::RpcServiceBuilder; -use jsonrpsee::server::{stop_channel, ws, BatchRequestConfig, PingConfig, StopHandle, TowerServiceBuilder}; -use jsonrpsee::{Methods, RpcModule}; -use tokio::net::TcpListener; use tokio::task::JoinSet; use tower::Service; -use tower_http::cors::{AllowOrigin, CorsLayer}; use mp_utils::wait_or_graceful_shutdown; -use super::middleware::{Metrics, MiddlewareLayer, RpcMetrics, VersionMiddlewareLayer}; +use crate::service::rpc::middleware::{RpcMiddlewareLayerRateLimit, RpcMiddlewareServiceVersion}; + +use super::metrics::RpcMetrics; +use super::middleware::{Metrics, RpcMiddlewareLayerMetrics}; const MEGABYTE: u32 = 1024 * 1024; @@ -41,23 +30,19 @@ pub struct ServerConfig { pub max_payload_out_mb: u32, pub metrics: RpcMetrics, pub message_buffer_capacity: u32, - pub rpc_api: RpcModule<()>, + pub rpc_api: jsonrpsee::RpcModule<()>, /// Batch request config. - pub batch_config: BatchRequestConfig, + pub batch_config: jsonrpsee::server::BatchRequestConfig, /// Rate limit calls per minute. pub rate_limit: Option, - /// Disable rate limit for certain ips. - pub rate_limit_whitelisted_ips: Vec, - /// Trust proxy headers for rate limiting. - pub rate_limit_trust_proxy_headers: bool, } #[derive(Debug, Clone)] struct PerConnection { - methods: Methods, - stop_handle: StopHandle, + methods: jsonrpsee::Methods, + stop_handle: jsonrpsee::server::StopHandle, metrics: RpcMetrics, - service_builder: TowerServiceBuilder, + service_builder: jsonrpsee::server::TowerServiceBuilder, } /// Start RPC server listening on given address. @@ -77,102 +62,77 @@ pub async fn start_server( message_buffer_capacity, rpc_api, rate_limit, - rate_limit_whitelisted_ips, - rate_limit_trust_proxy_headers, } = config; - let std_listener = TcpListener::bind(addr) + let listener = tokio::net::TcpListener::bind(addr) .await - .and_then(|a| a.into_std()) - .with_context(|| format!("binding to address: {addr}"))?; - let local_addr = std_listener.local_addr().ok(); - let host_filter = host_filtering(cors.is_some(), local_addr); + .with_context(|| format!("Binding TCP listener to address: {addr}"))?; + let local_addr = listener.local_addr().context("Failed to retrieve local address after binding TCP listener")?; + + let ping_config = jsonrpsee::server::PingConfig::new() + .ping_interval(Duration::from_secs(30)) + .inactive_limit(Duration::from_secs(60)) + .max_failures(3); let http_middleware = tower::ServiceBuilder::new() - .option_layer(host_filter) - // Proxy `GET /health` requests to internal `system_health` method. - // .layer(ProxyGetRequestLayer::new("/health", "system_health")?) - .layer(VersionMiddlewareLayer) - .layer(try_into_cors(cors.as_ref())?); + .option_layer(host_filtering(cors.is_some(), local_addr)) + .layer(try_into_cors(cors.as_ref())?); let builder = jsonrpsee::server::Server::builder() .max_request_body_size(max_payload_in_mb.saturating_mul(MEGABYTE)) .max_response_body_size(max_payload_out_mb.saturating_mul(MEGABYTE)) .max_connections(max_connections) .max_subscriptions_per_connection(max_subs_per_conn) - .enable_ws_ping( - PingConfig::new() - .ping_interval(Duration::from_secs(30)) - .inactive_limit(Duration::from_secs(60)) - .max_failures(3), - ) - .set_http_middleware(http_middleware) + .enable_ws_ping(ping_config) .set_message_buffer_capacity(message_buffer_capacity) .set_batch_request_config(batch_config) - .set_id_provider(RandomStringIdProvider::new(16)); + .set_http_middleware(http_middleware) + .set_id_provider(jsonrpsee::server::RandomStringIdProvider::new(16)); - let (stop_handle, server_handle) = stop_channel(); + let (stop_handle, server_handle) = jsonrpsee::server::stop_channel(); let cfg = PerConnection { methods: build_rpc_api(rpc_api).into(), - service_builder: builder.to_service_builder(), - metrics, stop_handle: stop_handle.clone(), + metrics, + service_builder: builder.to_service_builder(), }; - let make_service = make_service_fn(move |addr: &AddrStream| { + let make_service = hyper::service::make_service_fn(move |_| { let cfg = cfg.clone(); - let rate_limit_whitelisted_ips = rate_limit_whitelisted_ips.clone(); - let ip = addr.remote_addr().ip(); async move { let cfg = cfg.clone(); - let rate_limit_whitelisted_ips = rate_limit_whitelisted_ips.clone(); - - Ok::<_, Infallible>(service_fn(move |req| { - let proxy_ip = if rate_limit_trust_proxy_headers { get_proxy_ip(&req) } else { None }; - - let rate_limit_cfg = if rate_limit_whitelisted_ips - .iter() - .any(|ips| ips.contains(proxy_ip.unwrap_or(ip))) - { - log::debug!(target: "rpc", "ip={ip}, proxy_ip={:?} is trusted, disabling rate-limit", proxy_ip); - None - } else { - if !rate_limit_whitelisted_ips.is_empty() { - log::debug!(target: "rpc", "ip={ip}, proxy_ip={:?} is not trusted, rate-limit enabled", proxy_ip); - } - rate_limit - }; + Ok::<_, Infallible>(hyper::service::service_fn(move |req| { let PerConnection { service_builder, metrics, stop_handle, methods } = cfg.clone(); - let is_websocket = ws::is_upgrade_request(&req); + let is_websocket = jsonrpsee::server::ws::is_upgrade_request(&req); let transport_label = if is_websocket { "ws" } else { "http" }; + let path = req.uri().path().to_string(); + let metrics_layer = RpcMiddlewareLayerMetrics::new(Metrics::new(metrics, transport_label)); - let middleware_layer = match rate_limit_cfg { - None => MiddlewareLayer::new().with_metrics(Metrics::new(metrics, transport_label)), - Some(rate_limit) => MiddlewareLayer::new() - .with_metrics(Metrics::new(metrics, transport_label)) - .with_rate_limit_per_minute(rate_limit), - }; - - let rpc_middleware = RpcServiceBuilder::new().layer(middleware_layer.clone()); + let rpc_middleware = jsonrpsee::server::RpcServiceBuilder::new() + .layer_fn(move |service| RpcMiddlewareServiceVersion::new(service, path.clone())) + .option_layer(rate_limit.map(RpcMiddlewareLayerRateLimit::new)) + .layer(metrics_layer.clone()); let mut svc = service_builder.set_rpc_middleware(rpc_middleware).build(methods, stop_handle); async move { if req.uri().path() == "/health" { - Ok(Response::builder().status(StatusCode::OK).body(Body::from("OK"))?) + Ok(hyper::Response::builder().status(hyper::StatusCode::OK).body(hyper::Body::from("OK"))?) } else { if is_websocket { + // Utilize the session close future to know when the actual WebSocket + // session was closed. let on_disconnect = svc.on_session_closed(); // Spawn a task to handle when the connection is closed. tokio::spawn(async move { let now = std::time::Instant::now(); - middleware_layer.ws_connect(); + metrics_layer.ws_connect(); on_disconnect.await; - middleware_layer.ws_disconnect(now); + metrics_layer.ws_disconnect(now); }); } @@ -183,16 +143,12 @@ pub async fn start_server( } }); - let server = hyper::Server::from_tcp(std_listener) + let server = hyper::Server::from_tcp(listener.into_std()?) .with_context(|| format!("Creating hyper server at: {addr}"))? .serve(make_service); join_set.spawn(async move { - log::info!( - "📱 Running JSON-RPC server at {} (allowed origins={})", - local_addr.map_or_else(|| "unknown".to_string(), |a| a.to_string()), - format_cors(cors.as_ref()) - ); + log::info!("📱 Running JSON-RPC server at {} (allowed origins={})", local_addr, format_cors(cors.as_ref())); server .with_graceful_shutdown(async { wait_or_graceful_shutdown(stop_handle.shutdown()).await; @@ -204,31 +160,49 @@ pub async fn start_server( Ok(server_handle) } -const X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for"); -const X_REAL_IP: HeaderName = HeaderName::from_static("x-real-ip"); -const FORWARDED: HeaderName = HeaderName::from_static("forwarded"); - -pub(crate) fn host_filtering(enabled: bool, addr: Option) -> Option { - // If the local_addr failed, fallback to wildcard. - let port = addr.map_or("*".to_string(), |p| p.port().to_string()); - +// Copied from https://github.com/paritytech/polkadot-sdk/blob/a0aefc6b233ace0a82a8631d67b6854e6aeb014b/substrate/client/rpc-servers/src/utils.rs#L192 +pub(crate) fn host_filtering( + enabled: bool, + addr: SocketAddr, +) -> Option { if enabled { // NOTE: The listening addresses are whitelisted by default. - let hosts = [format!("localhost:{port}"), format!("127.0.0.1:{port}"), format!("[::1]:{port}")]; - Some(HostFilterLayer::new(hosts).expect("Invalid host filter")) + + let mut hosts = Vec::new(); + + if addr.is_ipv4() { + hosts.push(format!("localhost:{}", addr.port())); + hosts.push(format!("127.0.0.1:{}", addr.port())); + } else { + hosts.push(format!("[::1]:{}", addr.port())); + } + + Some(jsonrpsee::server::middleware::http::HostFilterLayer::new(hosts).expect("Valid hosts; qed")) } else { None } } -pub(crate) fn build_rpc_api(mut rpc_api: RpcModule) -> RpcModule { - let mut available_methods = rpc_api.method_names().collect::>(); +pub(crate) fn build_rpc_api(mut rpc_api: jsonrpsee::RpcModule) -> jsonrpsee::RpcModule { + let mut available_methods = rpc_api + .method_names() + .map(|name| { + let mut split = name.split("_"); + let namespace = split.next().expect("Should not be empty"); + let major = split.next().expect("Should not be empty"); + let minor = split.next().expect("Should not be empty"); + let patch = split.next().expect("Should not be empty"); + let method = split.next().expect("Should not be empty"); + + format!("rpc/{major}_{minor}_{patch}/{namespace}_{method}") + }) + .collect::>(); // The "rpc_methods" is defined below and we want it to be part of the reported methods. // The available methods will be prefixed by their version, example: - // * starknet_V0_7_1_blockNumber, - // * starknet_V0_8_0_blockNumber (...) - available_methods.push("rpc_methods"); + // * rpc/v0_7_1/starknet_blockNumber, + // * rpc/v0_8_0/starknet_blockNumber (...) + available_methods.push("rpc/rpc_methods".to_string()); available_methods.sort(); rpc_api @@ -242,16 +216,15 @@ pub(crate) fn build_rpc_api(mut rpc_api: RpcModule) rpc_api } -pub(crate) fn try_into_cors(maybe_cors: Option<&Vec>) -> anyhow::Result { +pub(crate) fn try_into_cors(maybe_cors: Option<&Vec>) -> anyhow::Result { if let Some(cors) = maybe_cors { let mut list = Vec::new(); for origin in cors { - list.push(HeaderValue::from_str(origin)?); + list.push(hyper::header::HeaderValue::from_str(origin)?); } - Ok(CorsLayer::new().allow_origin(AllowOrigin::list(list))) + Ok(tower_http::cors::CorsLayer::new().allow_origin(tower_http::cors::AllowOrigin::list(list))) } else { - // allow all cors - Ok(CorsLayer::permissive()) + Ok(tower_http::cors::CorsLayer::permissive()) } } @@ -262,38 +235,3 @@ pub(crate) fn format_cors(maybe_cors: Option<&Vec>) -> String { format!("{:?}", ["*"]) } } - -/// Extracts the IP addr from the HTTP request. -/// -/// It is extracted in the following order: -/// 1. `Forwarded` header. -/// 2. `X-Forwarded-For` header. -/// 3. `X-Real-Ip`. -pub(crate) fn get_proxy_ip(req: &Request) -> Option { - if let Some(ip) = req - .headers() - .get(&FORWARDED) - .and_then(|v| v.to_str().ok()) - .and_then(|v| ForwardedHeaderValue::from_forwarded(v).ok()) - .and_then(|v| v.remotest_forwarded_for_ip()) - { - return Some(ip); - } - - if let Some(ip) = req - .headers() - .get(&X_FORWARDED_FOR) - .and_then(|v| v.to_str().ok()) - .and_then(|v| ForwardedHeaderValue::from_x_forwarded_for(v).ok()) - .and_then(|v| v.remotest_forwarded_for_ip()) - { - return Some(ip); - } - - if let Some(ip) = req.headers().get(&X_REAL_IP).and_then(|v| v.to_str().ok()).and_then(|v| IpAddr::from_str(v).ok()) - { - return Some(ip); - } - - None -} diff --git a/crates/node/src/service/sync.rs b/crates/node/src/service/sync.rs index 5d0604b70..2ab8914ca 100644 --- a/crates/node/src/service/sync.rs +++ b/crates/node/src/service/sync.rs @@ -1,4 +1,4 @@ -use crate::cli::{NetworkType, SyncParams}; +use crate::cli::SyncParams; use anyhow::Context; use mc_block_import::BlockImporter; use mc_db::{DatabaseService, MadaraBackend}; @@ -26,14 +26,13 @@ impl SyncService { pub async fn new( config: &SyncParams, chain_config: Arc, - network: NetworkType, db: &DatabaseService, block_importer: Arc, telemetry: TelemetryHandle, ) -> anyhow::Result { - let fetch_config = config.block_fetch_config(chain_config.chain_id.clone(), network); + let fetch_config = config.block_fetch_config(chain_config.chain_id.clone(), chain_config.clone()); - log::info!("🛰️ Using feeder url: {} ", fetch_config.gateway.as_str()); + log::info!("🛰️ Using feeder gateway URL: {}", fetch_config.feeder_gateway.as_str()); Ok(Self { db_backend: Arc::clone(db.backend()), diff --git a/crates/primitives/block/src/header.rs b/crates/primitives/block/src/header.rs index ff178cff6..5207d56ce 100644 --- a/crates/primitives/block/src/header.rs +++ b/crates/primitives/block/src/header.rs @@ -210,6 +210,7 @@ impl Header { self.parent_block_hash, ]) } else { + // Based off https://github.com/starkware-libs/sequencer/blob/78ceca6aa230a63ca31f29f746fbb26d312fe381/crates/starknet_api/src/block_hash/block_hash_calculator.rs#L67 Poseidon::hash_array(&[ Felt::from_bytes_be_slice(b"STARKNET_BLOCK_HASH0"), Felt::from(self.block_number), diff --git a/crates/primitives/chain_config/Cargo.toml b/crates/primitives/chain_config/Cargo.toml index d373c4199..0df76791b 100644 --- a/crates/primitives/chain_config/Cargo.toml +++ b/crates/primitives/chain_config/Cargo.toml @@ -26,11 +26,13 @@ mp-utils.workspace = true # Other anyhow.workspace = true lazy_static.workspace = true +log.workspace = true primitive-types.workspace = true serde = { workspace = true, features = ["derive"] } serde_json.workspace = true serde_yaml.workspace = true thiserror.workspace = true +url.workspace = true [dev-dependencies] rstest.workspace = true diff --git a/crates/primitives/chain_config/src/chain_config.rs b/crates/primitives/chain_config/src/chain_config.rs index c55ec327f..3549a7b7c 100644 --- a/crates/primitives/chain_config/src/chain_config.rs +++ b/crates/primitives/chain_config/src/chain_config.rs @@ -23,6 +23,7 @@ use serde::de::{MapAccess, Visitor}; use serde::{Deserialize, Deserializer, Serialize}; use starknet_api::core::{ChainId, ContractAddress, PatriciaKey}; use starknet_types_core::felt::Felt; +use url::Url; use mp_utils::serde::{deserialize_duration, deserialize_private_key}; @@ -73,6 +74,10 @@ pub struct ChainConfig { pub chain_name: String, pub chain_id: ChainId, + // The Gateway URLs are the URLs of the endpoint that the node will use to sync blocks in full mode. + pub feeder_gateway_url: Url, + pub gateway_url: Url, + /// For starknet, this is the STRK ERC-20 contract on starknet. pub native_fee_token_address: ContractAddress, /// For starknet, this is the ETH ERC-20 contract on starknet. @@ -175,6 +180,8 @@ impl ChainConfig { Self { chain_name: "Starknet Mainnet".into(), chain_id: ChainId::Mainnet, + feeder_gateway_url: Url::parse("https://alpha-mainnet.starknet.io/feeder_gateway/").unwrap(), + gateway_url: Url::parse("https://alpha-mainnet.starknet.io/gateway/").unwrap(), native_fee_token_address: ContractAddress( PatriciaKey::try_from(Felt::from_hex_unchecked( "0x04718f5a0fc34cc1af16a1cdee98ffb20c31f5cd61d6ab07201858f4287c938d", @@ -236,6 +243,8 @@ impl ChainConfig { Self { chain_name: "Starknet Sepolia".into(), chain_id: ChainId::Sepolia, + feeder_gateway_url: Url::parse("https://alpha-sepolia.starknet.io/feeder_gateway/").unwrap(), + gateway_url: Url::parse("https://alpha-sepolia.starknet.io/gateway/").unwrap(), eth_core_contract_address: eth_core_contract_address::SEPOLIA_TESTNET.parse().expect("parsing a constant"), eth_gps_statement_verifier: eth_gps_statement_verifier::SEPOLIA_TESTNET .parse() @@ -248,6 +257,8 @@ impl ChainConfig { Self { chain_name: "Starknet Sepolia Integration".into(), chain_id: ChainId::IntegrationSepolia, + feeder_gateway_url: Url::parse("https://integration-sepolia.starknet.io/feeder_gateway/").unwrap(), + gateway_url: Url::parse("https://integration-sepolia.starknet.io/gateway/").unwrap(), eth_core_contract_address: eth_core_contract_address::SEPOLIA_INTEGRATION .parse() .expect("parsing a constant"), @@ -262,6 +273,8 @@ impl ChainConfig { Self { chain_name: "Madara".into(), chain_id: ChainId::Other("MADARA_DEVNET".into()), + feeder_gateway_url: Url::parse("http://localhost:8080/feeder_gateway/").unwrap(), + gateway_url: Url::parse("http://localhost:8080/gateway/").unwrap(), sequencer_address: Felt::from_hex_unchecked("0x123").try_into().unwrap(), ..ChainConfig::starknet_sepolia() } @@ -271,6 +284,8 @@ impl ChainConfig { Self { chain_name: "Test".into(), chain_id: ChainId::Other("MADARA_TEST".into()), + feeder_gateway_url: Url::parse("http://localhost:8080/feeder_gateway/").unwrap(), + gateway_url: Url::parse("http://localhost:8080/gateway/").unwrap(), // A random sequencer address for fee transfers to work in block production. sequencer_address: Felt::from_hex_unchecked( "0x211b748338b39fe8fa353819d457681aa50ac598a3db84cacdd6ece0a17e1f3", diff --git a/crates/primitives/chain_config/src/rpc_version.rs b/crates/primitives/chain_config/src/rpc_version.rs index f60bc51dd..f86c21d65 100644 --- a/crates/primitives/chain_config/src/rpc_version.rs +++ b/crates/primitives/chain_config/src/rpc_version.rs @@ -4,6 +4,7 @@ use std::str::FromStr; lazy_static::lazy_static! { pub static ref SUPPORTED_RPC_VERSIONS: Vec = vec![ RpcVersion::RPC_VERSION_0_7_1, + RpcVersion::RPC_VERSION_0_8_0, ]; } @@ -30,28 +31,38 @@ impl RpcVersion { } pub fn from_request_path(path: &str) -> Result { + log::debug!(target: "rpc_version", "extracting rpc version from request: {path}"); + let path = path.to_ascii_lowercase(); let parts: Vec<&str> = path.split('/').collect(); + log::debug!(target: "rpc_version", "version parts are: {parts:?}"); + // If we have an empty path or just "/", fallback to latest rpc version if parts.len() == 1 || (parts.len() == 2 && parts[1].is_empty()) { + log::debug!(target: "rpc_version", "no version, defaulting to latest"); return Ok(Self::RPC_VERSION_LATEST); } // Check if the path follows the correct format, i.e. /rpc/v[version]. // If not, fallback to the latest version if parts.len() != 3 || parts[1] != "rpc" || !parts[2].starts_with('v') { + log::debug!(target: "rpc_version", "invalid version format, defaulting to latest"); return Ok(Self::RPC_VERSION_LATEST); } + log::debug!(target: "rpc_version", "looking for matching version..."); let version_str = &parts[2][1..]; // without the 'v' prefix if let Ok(version) = RpcVersion::from_str(version_str) { if SUPPORTED_RPC_VERSIONS.contains(&version) { + log::debug!(target: "rpc_version", "found matching version: {version}"); Ok(version) } else { + log::debug!(target: "rpc_version", "no matching version"); Err(RpcVersionError::UnsupportedVersion) } } else { + log::debug!(target: "rpc_version", "invalid version format: {version_str}"); Err(RpcVersionError::InvalidVersion) } } @@ -69,6 +80,7 @@ impl RpcVersion { } pub const RPC_VERSION_0_7_1: RpcVersion = RpcVersion([0, 7, 1]); + pub const RPC_VERSION_0_8_0: RpcVersion = RpcVersion([0, 8, 0]); pub const RPC_VERSION_LATEST: RpcVersion = Self::RPC_VERSION_0_7_1; } diff --git a/crates/primitives/transactions/src/compute_hash.rs b/crates/primitives/transactions/src/compute_hash.rs index 7a9475c80..8cddefc38 100644 --- a/crates/primitives/transactions/src/compute_hash.rs +++ b/crates/primitives/transactions/src/compute_hash.rs @@ -22,6 +22,8 @@ const L1_HANDLER_PREFIX: Felt = Felt::from_hex_unchecked("0x6c315f68616e646c6572 const L1_GAS: &[u8] = b"L1_GAS"; const L2_GAS: &[u8] = b"L2_GAS"; +const PEDERSEN_EMPTY: Felt = + Felt::from_hex_unchecked("0x49ee3eba8c1600700ee1b87eb599f16716b0b1022947733551fde4050ca6804"); impl Transaction { pub fn compute_hash(&self, chain_id: Felt, version: StarknetVersion, offset_version: bool) -> Felt { @@ -58,60 +60,53 @@ impl Transaction { /// computing the transaction commitent uses a hash value that combines /// the transaction hash with the array of signature values. pub fn compute_hash_with_signature(&self, tx_hash: Felt, starknet_version: StarknetVersion) -> Felt { - let include_signature = starknet_version >= StarknetVersion::V0_11_1; - - let leaf = match self { - Transaction::Invoke(tx) => { - // Include signatures for Invoke transactions or for all transactions - if starknet_version < StarknetVersion::V0_13_2 { - let signature_hash = tx.compute_hash_signature::(); - Pedersen::hash(&tx_hash, &signature_hash) - } else { - let elements: Vec = std::iter::once(tx_hash).chain(tx.signature().iter().copied()).collect(); - Poseidon::hash_array(&elements) - } - } - Transaction::Declare(tx) => { - if include_signature { - if starknet_version < StarknetVersion::V0_13_2 { - let signature_hash = tx.compute_hash_signature::(); - Pedersen::hash(&tx_hash, &signature_hash) - } else { - let elements: Vec = - std::iter::once(tx_hash).chain(tx.signature().iter().copied()).collect(); - Poseidon::hash_array(&elements) - } - } else { - let signature_hash = Pedersen::hash_array(&[]); - Pedersen::hash(&tx_hash, &signature_hash) - } - } - Transaction::DeployAccount(tx) => { - if include_signature { - if starknet_version < StarknetVersion::V0_13_2 { - let signature_hash = tx.compute_hash_signature::(); - Pedersen::hash(&tx_hash, &signature_hash) - } else { - let elements: Vec = - std::iter::once(tx_hash).chain(tx.signature().iter().copied()).collect(); - Poseidon::hash_array(&elements) - } - } else { - let signature_hash = Pedersen::hash_array(&[]); - Pedersen::hash(&tx_hash, &signature_hash) - } - } - _ => { - if starknet_version < StarknetVersion::V0_13_2 { - let signature_hash = Pedersen::hash_array(&[]); - Pedersen::hash(&tx_hash, &signature_hash) - } else { - Poseidon::hash_array(&[tx_hash, Felt::ZERO]) - } - } + if starknet_version < StarknetVersion::V0_11_1 { + self.compute_hash_with_signature_pre_v0_11_1(tx_hash) + } else if starknet_version < StarknetVersion::V0_13_2 { + self.compute_hash_with_signature_pre_v0_13_2(tx_hash) + } else { + self.compute_hash_with_signature_latest(tx_hash) + } + } + + fn compute_hash_with_signature_pre_v0_11_1(&self, tx_hash: Felt) -> Felt { + let signature_hash = match self { + Transaction::Invoke(tx) => tx.compute_hash_signature::(), + Transaction::Declare(_) + | Transaction::DeployAccount(_) + | Transaction::Deploy(_) + | Transaction::L1Handler(_) => PEDERSEN_EMPTY, + }; + + Pedersen::hash(&tx_hash, &signature_hash) + } + + fn compute_hash_with_signature_pre_v0_13_2(&self, tx_hash: Felt) -> Felt { + let signature_hash = match self { + Transaction::Invoke(tx) => tx.compute_hash_signature::(), + Transaction::Declare(tx) => tx.compute_hash_signature::(), + Transaction::DeployAccount(tx) => tx.compute_hash_signature::(), + Transaction::Deploy(_) | Transaction::L1Handler(_) => PEDERSEN_EMPTY, + }; + + Pedersen::hash(&tx_hash, &signature_hash) + } + + fn compute_hash_with_signature_latest(&self, tx_hash: Felt) -> Felt { + let signature = match self { + Transaction::Invoke(tx) => tx.signature(), + Transaction::Declare(tx) => tx.signature(), + Transaction::DeployAccount(tx) => tx.signature(), + Transaction::Deploy(_) | Transaction::L1Handler(_) => &[], }; - leaf + let elements = if signature.is_empty() { + vec![tx_hash, Felt::ZERO] + } else { + std::iter::once(tx_hash).chain(signature.iter().copied()).collect() + }; + + Poseidon::hash_array(&elements) } } @@ -707,4 +702,9 @@ mod tests { Felt::from_hex_unchecked("0x734743d11641ecb3d92bafae091346fec3b2c75f7808e39f8b23d9287636e45"); assert_eq!(contract_address, expected_contract_address,); } + + #[test] + fn test_pedersen_empty() { + assert_eq!(PEDERSEN_EMPTY, Pedersen::hash_array(&[])) + } } diff --git a/crates/proc-macros/src/lib.rs b/crates/proc-macros/src/lib.rs index 83abd1df3..432861eb9 100644 --- a/crates/proc-macros/src/lib.rs +++ b/crates/proc-macros/src/lib.rs @@ -3,7 +3,7 @@ //! This macro is a wrapper around the "rpc" macro supplied by the jsonrpsee library that generates //! a server and client traits from a given trait definition. The wrapper gets a version id and -//! prepend the version id to the trait name and to every method name (note method name refers to +//! prepends the version id to the trait name and to every method name (note method name refers to //! the name the API has for the function not the actual function name). We need this in order to be //! able to merge multiple versions of jsonrpc APIs into one server and not have a clash in method //! resolution. @@ -12,9 +12,9 @@ //! //! Given this code: //! ```rust,ignore -//! #[versioned_starknet_rpc("V0_7_1")] +//! #[versioned_rpc("V0_7_1", "starknet")] //! pub trait JsonRpc { -//! #[method(name = "blockNumber")] +//! #[method(name = "blockNumber", aliases = ["block_number"])] //! fn block_number(&self) -> anyhow::Result; //! } //! ``` @@ -23,15 +23,19 @@ //! ```rust,ignore //! #[rpc(server, namespace = "starknet")] //! pub trait JsonRpcV0_7_1 { -//! #[method(name = "V0_7_1_blockNumber")] +//! #[method(name = "V0_7_1_blockNumber", aliases = ["block_number"])] //! fn block_number(&self) -> anyhow::Result; //! } //! ``` +//! +//! > [!NOTE] +//! > This macro _will not_ override any other jsonrpsee attribute, meaning +//! > it does not currently support renaming `aliases` or `unsubscribe_aliases` use proc_macro::TokenStream; use proc_macro2::Span; use quote::quote; -use syn::{parse::Parse, parse_macro_input, Attribute, Ident, ItemTrait, LitStr, TraitItem}; +use syn::spanned::Spanned; #[derive(Debug)] struct VersionedRpcAttr { @@ -39,11 +43,11 @@ struct VersionedRpcAttr { namespace: String, } -impl Parse for VersionedRpcAttr { +impl syn::parse::Parse for VersionedRpcAttr { fn parse(input: syn::parse::ParseStream) -> syn::Result { - let version = input.parse::()?.value(); + let version = input.parse::()?.value(); input.parse::()?; - let namespace = input.parse::()?.value(); + let namespace = input.parse::()?.value(); if !version.starts_with('V') { return Err(syn::Error::new(Span::call_site(), "Version must start with 'V'")); @@ -68,11 +72,12 @@ impl Parse for VersionedRpcAttr { return Err(syn::Error::new( Span::call_site(), indoc::indoc!( - " + r#" Namespace cannot be empty. Please provide a non-empty namespace string. - Example: #[versioned_rpc(\"V0_7_1\", \"starknet\")] - " + + ex: #[versioned_rpc("V0_7_1", "starknet")] + "# ), )); } @@ -81,52 +86,129 @@ impl Parse for VersionedRpcAttr { } } -fn version_method_name(attr: &Attribute, version: &str) -> syn::Result { - let mut new_attr = attr.clone(); - attr.parse_nested_meta(|meta| { - if meta.path.is_ident("name") { - let value = meta.value()?; - let method_name: LitStr = value.parse()?; - let new_name = format!("{version}_{}", method_name.value()); - new_attr.meta = syn::parse_quote!(method(name = #new_name)); - } - Ok(()) - })?; - Ok(new_attr) +enum CallType { + Method, + Subscribe, } #[proc_macro_attribute] pub fn versioned_rpc(attr: TokenStream, input: TokenStream) -> TokenStream { - let VersionedRpcAttr { version, namespace } = parse_macro_input!(attr as VersionedRpcAttr); - let mut item_trait = parse_macro_input!(input as ItemTrait); + let VersionedRpcAttr { version, namespace } = syn::parse_macro_input!(attr as VersionedRpcAttr); + let mut item_trait = syn::parse_macro_input!(input as syn::ItemTrait); let trait_name = &item_trait.ident; - let versioned_trait_name = Ident::new(&format!("{trait_name}{version}"), trait_name.span()); - - for item in &mut item_trait.items { - if let TraitItem::Fn(method) = item { - method.attrs = method - .attrs - .iter() - .filter_map(|attr| { - if attr.path().is_ident("method") { - version_method_name(attr, &version).ok() - } else { - Some(attr.clone()) + let train_name_with_version = syn::Ident::new(&format!("{trait_name}{version}"), trait_name.span()); + + // This next section is reponsible for versioning the method name declared + // with jsonrpsee + let err = item_trait.items.iter_mut().try_fold((), |_, item| { + let syn::TraitItem::Fn(method) = item else { + return Err(syn::Error::new( + item.span(), + indoc::indoc! {r#" + Traits marked with `versioned_rpc` can only contain methods + + ex: + + #[versioned_rpc("V0_7_0", "starknet")] + trait MyTrait { + #[method(name = "foo", blocking)] + fn foo(); + } + "#}, + )); + }; + + method.attrs.iter_mut().try_fold((), |_, attr| { + // We leave these errors to be handled by jsonrpsee + let path = attr.path(); + let ident = if path.is_ident("method") { + CallType::Method + } else if path.is_ident("subscription") { + CallType::Subscribe + } else { + return Ok(()); + }; + + let syn::Meta::List(meta_list) = &attr.meta else { + return Ok(()); + }; + + // This convoluted section is just the way by which we traverse + // the macro attribute list. We are looking for: + // + // - An assignment + // - With lvalue a Path expression with literal value `name` or + // 'unsubscribe' + // - With rvalue a literal + // + // Any other attribute is skipped over and is not overwritten + let attr_args = meta_list + .parse_args_with(syn::punctuated::Punctuated::::parse_terminated) + .map_err(|_| { + syn::Error::new( + meta_list.span(), + indoc::indoc! {r#" + The `method` and `subscription` attributes expect comma-separated values. + + ex: `#[method(name = "foo", blocking)]` + "#}, + ) + })? + .into_iter() + .map(|expr| { + // There isn't really a more elegant way of doing this as + // `left` and `right` are boxed values and therefore cannot + // be pattern matched without being de-referenced + let syn::Expr::Assign(expr) = expr else { return expr }; + + let syn::Expr::Path(syn::ExprPath { path, .. }) = *expr.left.clone() else { + return syn::Expr::Assign(expr); + }; + let syn::Expr::Lit(syn::ExprLit { lit: syn::Lit::Str(right), attrs }) = *expr.right.clone() else { + return syn::Expr::Assign(expr); + }; + + if !path.is_ident("name") && !path.is_ident("unsubscribe") { + return syn::Expr::Assign(expr); } + + let method_with_version = format!("{version}_{}", right.value()); + syn::Expr::Assign(syn::ExprAssign { + right: Box::new(syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Str(syn::LitStr::new(&method_with_version, right.span())), + attrs, + })), + ..expr + }) }) - .collect(); - } + .collect::>(); + + // This is the part where we actually replace the attribute with + // its versioned alternative. Note that the syntax #(#foo),* + // indicates a pattern repetition here, where all the elements in + // attr_args are expanded into rust code + attr.meta = match ident { + CallType::Method => syn::parse_quote!(method(#(#attr_args),*)), + CallType::Subscribe => syn::parse_quote!(subscription(#(#attr_args),*)), + }; + + Ok(()) + }) + }); + + if let Err(e) = err { + return e.into_compile_error().into(); } - let versioned_trait = ItemTrait { - attrs: vec![syn::parse_quote!(#[rpc(server, namespace = #namespace)])], - ident: versioned_trait_name, + let trait_with_version = syn::ItemTrait { + attrs: vec![syn::parse_quote!(#[jsonrpsee::proc_macros::rpc(server, namespace = #namespace)])], + ident: train_name_with_version, ..item_trait }; quote! { - #versioned_trait + #trait_with_version } .into() } @@ -134,7 +216,7 @@ pub fn versioned_rpc(attr: TokenStream, input: TokenStream) -> TokenStream { #[cfg(test)] mod tests { use super::*; - use quote::{quote, ToTokens}; + use quote::quote; use syn::parse_quote; #[test] @@ -165,19 +247,13 @@ mod tests { assert_eq!( result.unwrap_err().to_string(), indoc::indoc!( - " + r#" Namespace cannot be empty. Please provide a non-empty namespace string. - Example: #[versioned_rpc(\"V0_7_1\", \"starknet\")] - " + + ex: #[versioned_rpc("V0_7_1", "starknet")] + "# ) ); } - - #[test] - fn test_version_method_name() { - let attr: Attribute = parse_quote!(#[method(name = "blockNumber")]); - let result = version_method_name(&attr, "V0_7_1").unwrap(); - assert_eq!(result.to_token_stream().to_string(), "# [method (name = \"V0_7_1_blockNumber\")]"); - } }