From 85526ffc8ff8e695c316a16657b34d59ea746c4f Mon Sep 17 00:00:00 2001 From: Tim Vilgot Mikael Fredenberg Date: Fri, 16 Jan 2026 22:23:43 +0100 Subject: [PATCH] init multi-process template Move existing template to "single-process." --- .github/{template-values.toml => values.toml} | 0 .github/workflows/check.yml | 12 +- README.md | 37 +++- cargo-generate.toml | 9 +- init-script.rhai | 1 - .../.gitignore | 0 multi-process/Cargo.toml | 26 +++ multi-process/cargo-generate.toml | 5 + multi-process/gateway/Cargo.toml | 19 +++ multi-process/gateway/src/context.rs | 26 +++ multi-process/gateway/src/forward.rs | 158 ++++++++++++++++++ multi-process/gateway/src/main.rs | 108 ++++++++++++ {src => multi-process/gateway/src}/resume.rs | 0 multi-process/worker/Cargo.toml | 21 +++ multi-process/worker/src/cache.rs | 24 +++ multi-process/worker/src/command.rs | 55 ++++++ .../worker/src}/command/ping.rs | 0 multi-process/worker/src/context.rs | 27 +++ multi-process/worker/src/dispatch.rs | 74 ++++++++ multi-process/worker/src/main.rs | 72 ++++++++ single-process/.gitignore | 1 + Cargo.toml => single-process/Cargo.toml | 0 single-process/cargo-generate.toml | 6 + {src => single-process/src}/command.rs | 0 single-process/src/command/ping.rs | 44 +++++ .../src}/command/restart.rs | 0 {src => single-process/src}/context.rs | 0 {src => single-process/src}/dispatch.rs | 0 {src => single-process/src}/main.rs | 0 single-process/src/resume.rs | 90 ++++++++++ 30 files changed, 797 insertions(+), 18 deletions(-) rename .github/{template-values.toml => values.toml} (100%) delete mode 100644 init-script.rhai rename .gitignore.template => multi-process/.gitignore (100%) create mode 100644 multi-process/Cargo.toml create mode 100644 multi-process/cargo-generate.toml create mode 100644 multi-process/gateway/Cargo.toml create mode 100644 multi-process/gateway/src/context.rs create mode 100644 multi-process/gateway/src/forward.rs create mode 100644 multi-process/gateway/src/main.rs rename {src => 
multi-process/gateway/src}/resume.rs (100%) create mode 100644 multi-process/worker/Cargo.toml create mode 100644 multi-process/worker/src/cache.rs create mode 100644 multi-process/worker/src/command.rs rename {src => multi-process/worker/src}/command/ping.rs (100%) create mode 100644 multi-process/worker/src/context.rs create mode 100644 multi-process/worker/src/dispatch.rs create mode 100644 multi-process/worker/src/main.rs create mode 100644 single-process/.gitignore rename Cargo.toml => single-process/Cargo.toml (100%) create mode 100644 single-process/cargo-generate.toml rename {src => single-process/src}/command.rs (100%) create mode 100644 single-process/src/command/ping.rs rename {src => single-process/src}/command/restart.rs (100%) rename {src => single-process/src}/context.rs (100%) rename {src => single-process/src}/dispatch.rs (100%) rename {src => single-process/src}/main.rs (100%) create mode 100644 single-process/src/resume.rs diff --git a/.github/template-values.toml b/.github/values.toml similarity index 100% rename from .github/template-values.toml rename to .github/values.toml diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index 2a34c02..209eaa6 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -9,6 +9,9 @@ env: jobs: clippy: + strategy: + matrix: + variant: ["multi-process", "single-process"] runs-on: ubuntu-latest steps: @@ -24,12 +27,15 @@ jobs: with: tool: cargo-generate - - run: cargo generate --path . --name instance --template-values-file .github/template-values.toml + - run: cargo generate ${{ matrix.variant }} --path . --name instance --values-file .github/values.toml - run: cargo clippy working-directory: instance - fmt: + rustfmt: + strategy: + matrix: + variant: ["multi-process", "single-process"] runs-on: ubuntu-latest steps: @@ -45,7 +51,7 @@ jobs: with: tool: cargo-generate - - run: cargo generate --path . 
--name instance --template-values-file .github/template-values.toml + - run: cargo generate ${{ matrix.variant }} --path . --name instance --values-file .github/values.toml - run: cargo fmt -- --check working-directory: instance diff --git a/README.md b/README.md index 3d8fe51..c04e4d1 100644 --- a/README.md +++ b/README.md @@ -1,17 +1,42 @@ # Twilight Template -An opinionated [Twilight] bot template. +A set of opinionated [Twilight] bot templates. -The template is built in a modular fashion and provides the following features: +## Usage + +1. Install [cargo-generate]: `cargo install cargo-generate` +1. Create a bot based upon a select template: `cargo generate twilight-rs/template` + +## Variants + +- Single-process: fast restarts +- Multi-process: zero downtime restarts + +### Single-process -* Gateway session resumption between processes +This template is built in a modular fashion and provides the following features: + +* Gateway session resumption between restarts * Clean shutdown that awaits event handler completion * An administrative `/restart ` command -## Usage +### Multi-process -1. Install [cargo-generate]: `cargo install cargo-generate` -1. Create a bot based upon this template: `cargo generate twilight-rs/template` +This template is split into `gateway` and `worker` crates, where the `gateway` +forwards events and provides state to the `worker`. It otherwise mirrors the +single-process template: + +* gateway: + * Gateway session resumption between restarts + * Clean shutdown that awaits event forwarder completion +* worker: + * Clean shutdown that awaits event handler completion + +**Note**: this adds IPC overhead compared to the single-process template. You +may spread the load of the single-process template: + +1. Across threads with the multi-threaded tokio runtime +1. Across machines by partitioning your shards (i.e. 
each machine runs 1/X shards) [cargo-generate]: https://github.com/cargo-generate/cargo-generate [Twilight]: https://github.com/twilight-rs/twilight diff --git a/cargo-generate.toml b/cargo-generate.toml index 27b401a..e9f81a1 100644 --- a/cargo-generate.toml +++ b/cargo-generate.toml @@ -1,9 +1,2 @@ [template] -ignore = [".github", "Cargo.lock", "README.md", "target"] - -[hooks] -init = ["init-script.rhai"] - -[placeholders] -admin_guild_id = { prompt = "Enter the admin guild ID (where admin commands are available)", type = "string" } -application_id = { prompt = "Enter the application ID (available in the applications dashboard: )", type = "string" } +sub_templates = ["multi-process", "single-process"] diff --git a/init-script.rhai b/init-script.rhai deleted file mode 100644 index b74646e..0000000 --- a/init-script.rhai +++ /dev/null @@ -1 +0,0 @@ -file::rename(".gitignore.template", ".gitignore"); diff --git a/.gitignore.template b/multi-process/.gitignore similarity index 100% rename from .gitignore.template rename to multi-process/.gitignore diff --git a/multi-process/Cargo.toml b/multi-process/Cargo.toml new file mode 100644 index 0000000..f85470c --- /dev/null +++ b/multi-process/Cargo.toml @@ -0,0 +1,26 @@ +[workspace] +members = ["gateway", "worker"] +resolver = "3" + +[workspace.package] +edition = "2024" + +[workspace.dependencies] +anyhow = "1" +axum = "0.8" +futures-util = "0.3" +reqwest = "0.13" +rustls = "0.23" +serde = "1" +serde_json = "1" +tokio = "1" +tokio-stream = "0.1" +tokio-util = "0.7" +tokio-websockets = "0.13" +tracing = "0.1" +tracing-subscriber = "0.3" +twilight-cache-inmemory = "0.17" +twilight-gateway = "0.17" +twilight-http = "0.17" +twilight-model = "0.17" +twilight-util = "0.17" diff --git a/multi-process/cargo-generate.toml b/multi-process/cargo-generate.toml new file mode 100644 index 0000000..ce076b3 --- /dev/null +++ b/multi-process/cargo-generate.toml @@ -0,0 +1,5 @@ +[template] +ignore = ["Cargo.lock", "target"] + 
+[placeholders] +application_id = { prompt = "Enter the application ID (available in the applications dashboard: )", type = "string" } diff --git a/multi-process/gateway/Cargo.toml b/multi-process/gateway/Cargo.toml new file mode 100644 index 0000000..1ff8aa2 --- /dev/null +++ b/multi-process/gateway/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "gateway" +version = "0.1.0" +edition.workspace = true + +[dependencies] +anyhow.workspace = true +axum = { workspace = true, features = ["ws"] } +rustls.workspace = true +serde = { workspace = true, features = ["derive"] } +serde_json.workspace = true +tokio = { workspace = true, features = ["fs", "macros", "net", "rt", "signal"] } +tokio-stream.workspace = true +tracing.workspace = true +tracing-subscriber.workspace = true +twilight-cache-inmemory.workspace = true +twilight-gateway.workspace = true +twilight-http.workspace = true +twilight-model.workspace = true diff --git a/multi-process/gateway/src/context.rs b/multi-process/gateway/src/context.rs new file mode 100644 index 0000000..55c06b1 --- /dev/null +++ b/multi-process/gateway/src/context.rs @@ -0,0 +1,26 @@ +use std::{ops::Deref, sync::OnceLock}; +use tokio::sync::Notify; +use twilight_cache_inmemory::InMemoryCache; + +pub static CONTEXT: Ref = Ref(OnceLock::new()); + +#[derive(Debug)] +pub struct Context { + pub cache: InMemoryCache, + pub notify: Notify, +} + +pub fn init(cache: InMemoryCache, notify: Notify) { + let context = Context { cache, notify }; + assert!(CONTEXT.0.set(context).is_ok()); +} + +pub struct Ref(OnceLock); + +impl Deref for Ref { + type Target = Context; + + fn deref(&self) -> &Self::Target { + self.0.get().unwrap() + } +} diff --git a/multi-process/gateway/src/forward.rs b/multi-process/gateway/src/forward.rs new file mode 100644 index 0000000..5a810b8 --- /dev/null +++ b/multi-process/gateway/src/forward.rs @@ -0,0 +1,158 @@ +use crate::{CONTEXT, ResumeInfo}; +use anyhow::anyhow; +use axum::{ + body::Bytes, + extract::{WebSocketUpgrade, 
ws::Message as SocketMessage}, + response::Response, +}; +use serde::de::DeserializeSeed as _; +use std::{collections::VecDeque, error::Error, pin::pin}; +use tokio::{ + signal, + sync::broadcast::{self, error::RecvError}, +}; +use tokio_stream::StreamExt as _; +use twilight_gateway::{CloseFrame, Event, EventTypeFlags, Message as ShardMessage, Shard}; +use twilight_model::gateway::{OpCode, event::GatewayEventDeserializer}; + +const BUFFER_LIMIT: usize = 1000; +const EVENT_TYPES: EventTypeFlags = EventTypeFlags::INTERACTION_CREATE + .union(EventTypeFlags::READY) + .union(EventTypeFlags::ROLE_CREATE) + .union(EventTypeFlags::ROLE_DELETE) + .union(EventTypeFlags::ROLE_UPDATE); + +fn parse(input: &str) -> anyhow::Result> { + let deserializer = + GatewayEventDeserializer::from_json(input).ok_or_else(|| anyhow!("missing opcode"))?; + let opcode = OpCode::from(deserializer.op()).ok_or_else(|| anyhow!("unknown opcode"))?; + let event_type = EventTypeFlags::try_from((opcode, deserializer.event_type())) + .map_err(|_| anyhow!("missing event type"))?; + + Ok(EVENT_TYPES + .contains(event_type) + .then(|| { + let mut json_deserializer = serde_json::Deserializer::from_str(input); + deserializer.deserialize(&mut json_deserializer) + }) + .transpose()? + .map(Into::into)) +} + +enum ShardState { + Active, + Shutdown, +} + +impl ShardState { + fn is_active(&self) -> bool { + matches!(self, Self::Active) + } + + fn is_shutdown(&self) -> bool { + matches!(self, Self::Shutdown) + } +} + +#[tracing::instrument(fields(id = shard.id().number()), skip_all)] +pub async fn shard(event_tx: broadcast::Sender, mut shard: Shard) -> ResumeInfo { + let mut notified = None; + let mut buffer = VecDeque::new(); + let mut shutdown = pin!(signal::ctrl_c()); + let mut state = ShardState::Active; + + loop { + tokio::select! 
{ + _ = notified.as_mut().unwrap(), if notified.is_some() => { + notified = None; + loop { + let inner = CONTEXT.notify.notified(); + match event_tx.send(buffer.pop_front().unwrap()) { + Ok(_) if buffer.is_empty() => break, + Ok(_) => {} + Err(error) => { + notified = Some(Box::pin(inner)); + buffer.push_front(error.0); + break; + } + } + } + } + _ = &mut shutdown, if !state.is_shutdown() => { + if state.is_active() { + shard.close(CloseFrame::RESUME); + } + state = ShardState::Shutdown; + } + event = shard.next() => { + match event { + Some(Ok(ShardMessage::Close(_))) if !state.is_active() => break, + Some(Ok(ShardMessage::Close(_))) => {} + Some(Ok(ShardMessage::Text(json))) => { + match parse(&json) { + Ok(Some(event)) => { + CONTEXT.cache.update(&event); + + let inner = notified.is_none().then(|| CONTEXT.notify.notified()); + if let Err(error) = event_tx.send(json.into()) { + if let Some(inner) = inner { + notified = Some(Box::pin(inner)); + } + + if buffer.len() == BUFFER_LIMIT { + buffer.pop_front(); + } + buffer.push_back(error.0); + } + } + Ok(_) => {} + Err(error) => tracing::warn!(error = &*error, "failed to deserialize event"), + } + } + Some(Err(error)) => tracing::warn!(error = &error as &dyn Error, "shard failed to receive event"), + None => break, + } + } + } + } + + return ResumeInfo::from(&shard); +} + +#[tracing::instrument(skip_all)] +pub async fn socket(ws: WebSocketUpgrade, weak: broadcast::WeakSender) -> Response { + ws.on_upgrade(async move |mut socket| { + if let Some(mut event_rx) = weak.upgrade().map(|tx| tx.subscribe()) { + tracing::info!("worker connected"); + CONTEXT.notify.notify_waiters(); + + loop { + tokio::select! 
{ + message = socket.recv() => { + match message { + Some(Ok(SocketMessage::Close(_))) => return, + Some(Err(error)) => tracing::warn!(error = &error as &dyn Error, "socket failed to receive event"), + None => return, + _ => {} + } + } + event = event_rx.recv() => { + match event { + Ok(event) => { + if let Err(error) = socket.send(SocketMessage::Text(event.try_into().unwrap())).await { + tracing::warn!(error = &error as &dyn Error, "socket failed to send event"); + return; + } + } + Err(RecvError::Closed) => return, + Err(RecvError::Lagged(count)) => tracing::warn!("socket lagged {count} events"), + } + } + } + } + } + + // Drive socket to completion. + while socket.recv().await.is_some() {} + }) +} diff --git a/multi-process/gateway/src/main.rs b/multi-process/gateway/src/main.rs new file mode 100644 index 0000000..c37a565 --- /dev/null +++ b/multi-process/gateway/src/main.rs @@ -0,0 +1,108 @@ +mod context; +mod forward; +mod resume; + +pub(crate) use self::{context::CONTEXT, resume::Info as ResumeInfo}; + +use anyhow::Context as _; +use axum::{ + Json, Router, + extract::Path, + response::Result, + routing::{any, get}, +}; +use std::{env, iter, time::Duration}; +use tokio::{ + net::TcpListener, + signal, + sync::{Notify, broadcast}, +}; +use twilight_cache_inmemory::{InMemoryCache, ResourceType}; +use twilight_gateway::{ConfigBuilder, Intents, queue::InMemoryQueue}; +use twilight_http::Client; +use twilight_model::{ + guild::Role, + id::{Id, marker::RoleMarker}, +}; + +const SHARD_EVENT_BUFFER: usize = 16; +const INTENTS: Intents = Intents::GUILDS; +const RESOURCE_TYPES: ResourceType = ResourceType::ROLE; + +#[tokio::main(flavor = "current_thread")] +async fn main() -> anyhow::Result<()> { + tracing_subscriber::fmt::init(); + + context::init( + InMemoryCache::builder() + .resource_types(RESOURCE_TYPES) + .build(), + Notify::new(), + ); + + let token = env::var("TOKEN").context("failed to get `TOKEN`")?; + + let http = Client::new(token.clone()); + let info = 
async { Ok::<_, anyhow::Error>(http.gateway().authed().await?.model().await?) } + .await + .context("failed to get info")?; + + // The queue defaults are static and may be incorrect for large or newly + // restarted bots. + let queue = InMemoryQueue::new( + info.session_start_limit.max_concurrency, + info.session_start_limit.remaining, + Duration::from_millis(info.session_start_limit.reset_after), + info.session_start_limit.total, + ); + let config = ConfigBuilder::new(token, INTENTS).queue(queue).build(); + + let shards = resume::restore(config, info.shards).await; + + let (event_tx, _) = broadcast::channel(shards.len() * SHARD_EVENT_BUFFER); + let router = Router::new() + .route( + "/events", + any({ + let weak = event_tx.downgrade(); + move |body| forward::socket(body, weak) + }), + ) + .route("/roles/{id}", get(get_role)); + + tokio::spawn({ + let listener = TcpListener::bind("[::1]:3000").await?; + async { axum::serve(listener, router).await.unwrap() } + }); + let tasks = iter::repeat_n(event_tx, shards.len()) + .zip(shards) + .map(|(event_tx, shard)| tokio::spawn(forward::shard(event_tx, shard))) + .collect::>(); + + signal::ctrl_c().await?; + tracing::info!("shutting down; press CTRL-C to abort"); + + let join_all_tasks = async { + let mut resume_info = Vec::new(); + for task in tasks { + resume_info.push(task.await?); + } + Ok::<_, anyhow::Error>(resume_info) + }; + let resume_info = tokio::select! { + _ = signal::ctrl_c() => Vec::new(), + resume_info = join_all_tasks => resume_info?, + }; + + // Save shard information to be restored. 
+ resume::save(&resume_info) + .await + .context("failed to save resume info")?; + + Ok(()) +} + +async fn get_role(Path(role_id): Path>) -> Result> { + let role = CONTEXT.cache.role(role_id).ok_or("not found")?; + Ok(Json(role.value().resource().clone())) +} diff --git a/src/resume.rs b/multi-process/gateway/src/resume.rs similarity index 100% rename from src/resume.rs rename to multi-process/gateway/src/resume.rs diff --git a/multi-process/worker/Cargo.toml b/multi-process/worker/Cargo.toml new file mode 100644 index 0000000..33d4b39 --- /dev/null +++ b/multi-process/worker/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "worker" +version = "0.1.0" +edition.workspace = true + +[dependencies] +anyhow.workspace = true +futures-util = { workspace = true, features = ["sink"] } +reqwest = { workspace = true, features = ["json"] } +rustls.workspace = true +serde.workspace = true +serde_json.workspace = true +tokio = { workspace = true, features = ["macros","rt", "signal"] } +tokio-stream.workspace = true +tokio-util = { workspace = true, features = ["rt"] } +tokio-websockets = { workspace = true, features = ["aws_lc_rs", "client", "fastrand"] } +tracing.workspace = true +tracing-subscriber.workspace = true +twilight-http.workspace = true +twilight-model.workspace = true +twilight-util = { workspace = true, features = ["builder"] } diff --git a/multi-process/worker/src/cache.rs b/multi-process/worker/src/cache.rs new file mode 100644 index 0000000..65e78c3 --- /dev/null +++ b/multi-process/worker/src/cache.rs @@ -0,0 +1,24 @@ +#![allow(dead_code)] +use reqwest::{Client, IntoUrl, Response, Result}; +use twilight_model::{ + guild::Role, + id::{Id, marker::RoleMarker}, +}; + +#[derive(Debug)] +pub struct Cache(Client); + +impl Cache { + pub fn new() -> Self { + Self(Client::new()) + } + + async fn get(&self, url: impl IntoUrl) -> Result { + self.0.get(url).send().await + } + + pub async fn role(&self, role_id: Id) -> Result { + let url = 
format!("http://[::1]:3000/roles/{role_id}"); + self.get(url).await?.json().await + } +} diff --git a/multi-process/worker/src/command.rs b/multi-process/worker/src/command.rs new file mode 100644 index 0000000..7e55d51 --- /dev/null +++ b/multi-process/worker/src/command.rs @@ -0,0 +1,55 @@ +mod ping; + +use twilight_model::{ + application::{ + command::Command, + interaction::{InteractionData, InteractionType}, + }, + gateway::payload::incoming::InteractionCreate, +}; + +pub fn commands() -> [Command; 1] { + [ping::command()] +} + +#[derive(Clone, Copy, Debug)] +enum Kind { + Ping, +} + +impl From<&str> for Kind { + fn from(name: &str) -> Self { + match name { + ping::NAME => Kind::Ping, + _ => panic!("unknown command name: '{name}'"), + } + } +} + +pub async fn interaction(mut event: Box) -> anyhow::Result<()> { + match event.kind { + InteractionType::ApplicationCommandAutocomplete => { + let InteractionData::ApplicationCommand(data) = event.data.take().unwrap() else { + unreachable!(); + }; + let kind = data.name.as_str().into(); + + match kind { + Kind::Ping => ping::autocomplete(event, data).await?, + } + } + InteractionType::ApplicationCommand => { + let InteractionData::ApplicationCommand(data) = event.data.take().unwrap() else { + unreachable!(); + }; + let kind = data.name.as_str().into(); + + match kind { + Kind::Ping => ping::run(event, data).await?, + } + } + _ => {} + } + + Ok(()) +} diff --git a/src/command/ping.rs b/multi-process/worker/src/command/ping.rs similarity index 100% rename from src/command/ping.rs rename to multi-process/worker/src/command/ping.rs diff --git a/multi-process/worker/src/context.rs b/multi-process/worker/src/context.rs new file mode 100644 index 0000000..71749b3 --- /dev/null +++ b/multi-process/worker/src/context.rs @@ -0,0 +1,27 @@ +use crate::Cache; +use std::{ops::Deref, sync::OnceLock}; +use twilight_http::Client; + +pub static CONTEXT: Ref = Ref(OnceLock::new()); + +#[derive(Debug)] +pub struct Context { + 
#[allow(dead_code)] + pub cache: Cache, + pub http: Client, +} + +pub fn init(cache: Cache, http: Client) { + let context = Context { cache, http }; + assert!(CONTEXT.0.set(context).is_ok()); +} + +pub struct Ref(OnceLock); + +impl Deref for Ref { + type Target = Context; + + fn deref(&self) -> &Self::Target { + self.0.get().unwrap() + } +} diff --git a/multi-process/worker/src/dispatch.rs b/multi-process/worker/src/dispatch.rs new file mode 100644 index 0000000..aef9a8b --- /dev/null +++ b/multi-process/worker/src/dispatch.rs @@ -0,0 +1,74 @@ +use futures_util::SinkExt as _; +use serde::de::DeserializeSeed as _; +use std::{error::Error, pin::pin}; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + signal, +}; +use tokio_stream::StreamExt as _; +use tokio_util::task::TaskTracker; +use tokio_websockets::{Message, WebSocketStream}; +use twilight_model::gateway::event::{Event, GatewayEventDeserializer}; + +enum State { + Active, + Shutdown, +} + +impl State { + fn is_shutdown(&self) -> bool { + matches!(self, Self::Shutdown) + } +} + +#[tracing::instrument(name = "dispatcher", skip_all)] +pub async fn run + Send + 'static>( + mut event_handler: impl FnMut(Event) -> Fut, + mut socket: WebSocketStream, +) { + let mut shutdown = pin!(signal::ctrl_c()); + let mut state = State::Active; + let tracker = TaskTracker::new(); + + loop { + tokio::select! 
{ + _ = &mut shutdown, if !state.is_shutdown() => { + if let Err(error) = socket.send(Message::close(None, "")).await { + tracing::warn!(error = &error as &dyn Error, "socket failed to send message"); + } + state = State::Shutdown; + } + message = socket.next() => { + match message { + Some(Ok(message)) => { + if message.is_text() { + let event = message.as_text().unwrap(); + let Some(deserializer) = GatewayEventDeserializer::from_json(event) else { + tracing::warn!(event, "failed to deserialize event"); + continue; + }; + let mut json_deserializer = serde_json::Deserializer::from_str(event); + match deserializer.deserialize(&mut json_deserializer) { + Ok(event) => _ = tracker.spawn(event_handler(event.into())), + Err(error) => tracing::warn!(error = &error as &dyn Error, "failed to deserialize event"), + } + } + } + Some(Err(error)) => tracing::warn!(error = &error as &dyn Error, "socket failed to receive message"), + None => { + tracing::info!("gateway shut down"); + break; + }, + } + } + } + } + + if let Err(error) = socket.close().await { + tracing::warn!(error = &error as &dyn Error, "socket failed to close"); + } + + tracker.close(); + tracing::info!("waiting for {} task(s) to finish", tracker.len()); + tracker.wait().await; +} diff --git a/multi-process/worker/src/main.rs b/multi-process/worker/src/main.rs new file mode 100644 index 0000000..f08d012 --- /dev/null +++ b/multi-process/worker/src/main.rs @@ -0,0 +1,72 @@ +mod cache; +mod command; +mod context; +mod dispatch; + +pub(crate) use self::{cache::Cache, context::CONTEXT}; + +use anyhow::Context as _; +use std::{env, pin::pin}; +use tokio::signal; +use tokio_websockets::{ClientBuilder, Limits}; +use tracing::{Instrument as _, instrument::Instrumented}; +use twilight_http::Client; +use twilight_model::{ + gateway::event::Event, + id::{Id, marker::ApplicationMarker}, +}; + +#[rustfmt::skip] +const APPLICATION_ID: Id = Id::new({{application_id}}); + +#[tokio::main(flavor = "current_thread")] +async fn 
main() -> anyhow::Result<()> { + tracing_subscriber::fmt::init(); + + let token = env::var("TOKEN").context("failed to get `TOKEN`")?; + + let http = Client::new(token.clone()); + http.interaction(APPLICATION_ID) + .set_global_commands(&command::commands()) + .await + .context("failed to put commands")?; + context::init(Cache::new(), http); + + let task = tokio::spawn({ + let (socket, _) = ClientBuilder::new() + .limits(Limits::unlimited()) + .uri("ws://[::1]:3000/events")? + .connect() + .await?; + dispatch::run(event_handler, socket) + }); + + signal::ctrl_c().await?; + tracing::info!("shutting down; press CTRL-C to abort"); + + tokio::select! { + _ = signal::ctrl_c() => {}, + result = task => result?, + } + + Ok(()) +} + +async fn event_handler(event: Event) { + async fn log_err(future: Instrumented>>) { + let mut future = pin!(future); + if let Err(error) = future.as_mut().await { + let _enter = future.span().enter(); + tracing::warn!(error = &*error, "failed to handle event"); + } + } + + #[allow(clippy::single_match)] + match event { + Event::InteractionCreate(event) => { + let span = tracing::info_span!("interaction", id = %event.id); + log_err(command::interaction(event).instrument(span)).await; + } + _ => {} + } +} diff --git a/single-process/.gitignore b/single-process/.gitignore new file mode 100644 index 0000000..ea8c4bf --- /dev/null +++ b/single-process/.gitignore @@ -0,0 +1 @@ +/target diff --git a/Cargo.toml b/single-process/Cargo.toml similarity index 100% rename from Cargo.toml rename to single-process/Cargo.toml diff --git a/single-process/cargo-generate.toml b/single-process/cargo-generate.toml new file mode 100644 index 0000000..42bb406 --- /dev/null +++ b/single-process/cargo-generate.toml @@ -0,0 +1,6 @@ +[template] +ignore = ["Cargo.lock", "target"] + +[placeholders] +admin_guild_id = { prompt = "Enter the admin guild ID (where admin commands are available)", type = "string" } +application_id = { prompt = "Enter the application ID (available 
in the applications dashboard: )", type = "string" } diff --git a/src/command.rs b/single-process/src/command.rs similarity index 100% rename from src/command.rs rename to single-process/src/command.rs diff --git a/single-process/src/command/ping.rs b/single-process/src/command/ping.rs new file mode 100644 index 0000000..20933eb --- /dev/null +++ b/single-process/src/command/ping.rs @@ -0,0 +1,44 @@ +use crate::{APPLICATION_ID, CONTEXT}; +use twilight_model::{ + application::{ + command::{Command, CommandType}, + interaction::application_command::CommandData, + }, + channel::message::MessageFlags, + gateway::payload::incoming::InteractionCreate, + http::interaction::{InteractionResponse, InteractionResponseData, InteractionResponseType}, +}; +use twilight_util::builder::command::CommandBuilder; + +pub const NAME: &str = "ping"; + +pub fn command() -> Command { + CommandBuilder::new(NAME, "Ping the bot", CommandType::ChatInput).build() +} + +pub async fn autocomplete( + _event: Box, + _data: Box, +) -> anyhow::Result<()> { + Ok(()) +} + +pub async fn run(event: Box, _data: Box) -> anyhow::Result<()> { + let data = InteractionResponseData { + content: Some("Pong!".to_owned()), + flags: Some(MessageFlags::EPHEMERAL), + ..Default::default() + }; + + let response = InteractionResponse { + kind: InteractionResponseType::ChannelMessageWithSource, + data: Some(data), + }; + CONTEXT + .http + .interaction(APPLICATION_ID) + .create_response(event.id, &event.token, &response) + .await?; + + Ok(()) +} diff --git a/src/command/restart.rs b/single-process/src/command/restart.rs similarity index 100% rename from src/command/restart.rs rename to single-process/src/command/restart.rs diff --git a/src/context.rs b/single-process/src/context.rs similarity index 100% rename from src/context.rs rename to single-process/src/context.rs diff --git a/src/dispatch.rs b/single-process/src/dispatch.rs similarity index 100% rename from src/dispatch.rs rename to single-process/src/dispatch.rs 
diff --git a/src/main.rs b/single-process/src/main.rs similarity index 100% rename from src/main.rs rename to single-process/src/main.rs diff --git a/single-process/src/resume.rs b/single-process/src/resume.rs new file mode 100644 index 0000000..7f77538 --- /dev/null +++ b/single-process/src/resume.rs @@ -0,0 +1,90 @@ +use serde::{Deserialize, Serialize}; +use tokio::fs; +use twilight_gateway::{Config, ConfigBuilder, Session, Shard, ShardId}; + +const INFO_FILE: &str = "resume-info.json"; + +pub trait ConfigBuilderExt { + fn resume_info(self, resume_info: Info) -> Self; +} + +impl ConfigBuilderExt for ConfigBuilder { + fn resume_info(mut self, resume_info: Info) -> Self { + if let Some(resume_url) = resume_info.resume_url { + self = self.resume_url(resume_url); + } + if let Some(session) = resume_info.session { + self = self.session(session); + } + + self + } +} + +/// [`Shard`] session resumption information. +#[derive(Debug, Deserialize, Serialize)] +pub struct Info { + resume_url: Option, + session: Option, +} + +impl Info { + fn is_none(&self) -> bool { + self.resume_url.is_none() && self.session.is_none() + } +} + +impl From<&Shard> for Info { + fn from(value: &Shard) -> Self { + Self { + resume_url: value.resume_url().map(ToOwned::to_owned), + session: value.session().cloned(), + } + } +} + +/// Saves shard resumption information to the file system. +pub async fn save(info: &[Info]) -> anyhow::Result<()> { + if !info.iter().all(Info::is_none) { + let contents = serde_json::to_vec(&info)?; + fs::write(INFO_FILE, contents).await?; + } + + Ok(()) +} + +/// Restores shard resumption information from the file system. +pub async fn restore(config: Config, shards: u32) -> Vec { + let info = async { + let contents = fs::read(INFO_FILE).await?; + Ok::<_, anyhow::Error>(serde_json::from_slice::>(&contents)?) 
+ } .await; + + let shard_ids = (0..shards).map(|shard| ShardId::new(shard, shards)); + + // A session may only be successfully resumed if it retains its shard ID, but + // Discord may have recommended a different shard count (producing different shard + // IDs). + let shards: Vec<_> = if let Ok(info) = info + && info.len() == shards as usize + { + tracing::info!("resuming previous gateway sessions"); + shard_ids + .zip(info) + .map(|(shard_id, info)| { + let builder = ConfigBuilder::from(config.clone()).resume_info(info); + Shard::with_config(shard_id, builder.build()) + }) + .collect() + } else { + shard_ids + .map(|shard_id| Shard::with_config(shard_id, config.clone())) + .collect() + }; + + // Resumed or not, the saved resume info is now stale. + _ = fs::remove_file(INFO_FILE).await; + + shards +}