Skip to content

Commit 9f34d1c

Browse files
authored
feat: handle multiple ws addesses (#7)
1 parent 30817e4 commit 9f34d1c

File tree

3 files changed

+54
-15
lines changed

3 files changed

+54
-15
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ The project is structured as follows:
1818
## Running the chain pusher
1919

2020
```bash
21-
cargo run -- --auth_header "Bearer <your_auth_token>" --ws_url "ws_url" --cluster "https://devnet.magicblock.app"
21+
cargo run -- --auth-header "Bearer <your_auth_token>" --ws-urls "ws_url1,ws_url2" --cluster "https://devnet.magicblock.app"
2222
```
2323

2424
## Consuming Price Data in a Solana Program

src/args.rs

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use clap::{Parser, ValueEnum};
22
use solana_sdk::signature::Keypair;
3+
use tracing::warn;
34

45
#[derive(Debug, Clone, ValueEnum)]
56
pub enum ChannelType {
@@ -31,8 +32,18 @@ pub struct Args {
3132
pub private_key: Option<String>,
3233
#[arg(long, help = "Authorization header for the WebSocket connection")]
3334
pub auth_header: Option<String>,
34-
#[arg(long, help = "WebSocket URL for the price feed")]
35+
36+
#[arg(long,
37+
help = "DEPRECATED: Single WebSocket URL for the price feed",
38+
conflicts_with = "ws_urls",
39+
hide = true)]
3540
pub ws_url: Option<String>,
41+
42+
#[arg(long,
43+
help = "Comma-separated list of WebSocket URLs for the price feed",
44+
value_delimiter = ',')]
45+
pub ws_urls: Vec<String>,
46+
3647
#[arg(long, help = "Solana cluster URL")]
3748
pub cluster: Option<String>,
3849
#[arg(long, help = "Comma-separated list of price feeds")]
@@ -44,11 +55,28 @@ pub struct Args {
4455
pub channel: Option<ChannelType>,
4556
}
4657

47-
pub fn get_ws_url(cli_url: Option<String>) -> String {
48-
std::env::var("ORACLE_WS_URL")
49-
.ok()
50-
.or(cli_url)
51-
.unwrap_or_else(|| "ws://localhost:8765".to_string())
58+
pub fn get_ws_urls(cli_url: Option<String>, cli_urls: Vec<String>) -> Vec<String> {
59+
if cli_url.is_some() {
60+
warn!("'--ws-url' is deprecated, use '--ws-urls' with comma-separated list instead");
61+
}
62+
63+
let env_url = std::env::var("ORACLE_WS_URL").ok();
64+
let env_urls = std::env::var("ORACLE_WS_URLS").ok();
65+
66+
if !cli_urls.is_empty() {
67+
cli_urls
68+
} else if let Some(urls_str) = env_urls {
69+
urls_str
70+
.split(',')
71+
.map(|url| url.trim().to_string())
72+
.filter(|url| !url.is_empty())
73+
.collect()
74+
} else {
75+
let single_url = cli_url
76+
.or(env_url)
77+
.unwrap_or_else(|| "ws://localhost:8765".to_string());
78+
vec![single_url]
79+
}
5280
}
5381

5482
pub fn get_auth_header(cli_auth: Option<String>) -> String {

src/main.rs

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ use tokio_native_tls::TlsConnector;
2727
use tracing::{debug, error, info, warn};
2828
use url::Url;
2929

30-
use crate::args::{get_auth_header, get_channel, get_price_feeds, get_private_key, get_solana_cluster, get_ws_url, Args};
30+
use crate::args::{get_auth_header, get_channel, get_price_feeds, get_private_key, get_solana_cluster, get_ws_urls, Args};
3131
use crate::pyth_lazer::chain_pusher::PythChainPusher;
3232
use crate::stork::chain_pusher::StorkChainPusher;
3333
use crate::types::ChainPusher;
@@ -39,27 +39,38 @@ async fn main() {
3939
let args = Args::parse();
4040
let private_key = get_private_key(args.private_key);
4141
let auth_header = get_auth_header(args.auth_header);
42-
let ws_url = get_ws_url(args.ws_url);
42+
let ws_urls = get_ws_urls(args.ws_url, args.ws_urls);
4343
let cluster_url = get_solana_cluster(args.cluster);
4444
let price_feeds = get_price_feeds(args.price_feeds);
4545
let channel = get_channel(args.channel);
4646

4747
let payer = Keypair::from_base58_string(&private_key);
4848
info!(wallet_pubkey = ?payer.pubkey(), "Identity initialized");
4949

50-
let chain_pusher: Arc<dyn ChainPusher> = if ws_url.contains("stork") {
50+
let chain_pusher: Arc<dyn ChainPusher> = if ws_urls.iter().any(|url| url.contains("stork")) {
5151
Arc::new(StorkChainPusher::new(&cluster_url, payer).await)
5252
} else {
5353
Arc::new(PythChainPusher::new(&cluster_url, payer).await)
5454
};
5555

5656
loop {
57-
if let Err(e) =
58-
run_websocket_client(&chain_pusher, &ws_url, &auth_header, &price_feeds, &channel).await
59-
{
60-
error!(error = ?e, "WebSocket connection error, attempting reconnection");
57+
let mut last_error = None;
58+
59+
for ws_url in &ws_urls {
60+
match run_websocket_client(&chain_pusher, ws_url, &auth_header, &price_feeds, &channel).await {
61+
Ok(_) => break,
62+
Err(e) => {
63+
error!(error = ?e, url = ws_url, "WebSocket connection failed, trying next URL");
64+
last_error = Some(e);
65+
}
66+
}
67+
}
68+
69+
// if all URLs fail, wait before trying again
70+
if let Some(e) = last_error {
71+
error!(error = ?e, "All WebSocket URLs failed, retrying in 5 seconds");
72+
time::sleep(Duration::from_secs(5)).await;
6173
}
62-
time::sleep(Duration::from_secs(5)).await;
6374
}
6475
}
6576

0 commit comments

Comments
 (0)