Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion cargo-generate.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,3 @@ init = ["init-script.rhai"]

[placeholders]
admin_guild_id = { prompt = "Enter the admin guild ID (where admin commands are available)", type = "string" }
application_id = { prompt = "Enter the application ID (available in the applications dashboard: <https://discord.com/developers/applications>)", type = "string" }
16 changes: 8 additions & 8 deletions src/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,28 +33,28 @@ impl From<&str> for Kind {
}
}

pub async fn interaction(mut event: Box<InteractionCreate>) -> anyhow::Result<()> {
match event.kind {
pub async fn handler(mut interaction: Box<InteractionCreate>) -> anyhow::Result<()> {
match interaction.kind {
InteractionType::ApplicationCommandAutocomplete => {
let InteractionData::ApplicationCommand(data) = event.data.take().unwrap() else {
let InteractionData::ApplicationCommand(data) = interaction.data.take().unwrap() else {
unreachable!();
};
let kind = data.name.as_str().into();

match kind {
Kind::Ping => ping::autocomplete(event, data).await?,
Kind::Restart => restart::autocomplete(event, data).await?,
Kind::Ping => ping::autocomplete(interaction, data).await?,
Kind::Restart => restart::autocomplete(interaction, data).await?,
}
}
InteractionType::ApplicationCommand => {
let InteractionData::ApplicationCommand(data) = event.data.take().unwrap() else {
let InteractionData::ApplicationCommand(data) = interaction.data.take().unwrap() else {
unreachable!();
};
let kind = data.name.as_str().into();

match kind {
Kind::Ping => ping::run(event, data).await?,
Kind::Restart => restart::run(event, data).await?,
Kind::Ping => ping::run(interaction, data).await?,
Kind::Restart => restart::run(interaction, data).await?,
}
}
_ => {}
Expand Down
15 changes: 8 additions & 7 deletions src/command/ping.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{APPLICATION_ID, CONTEXT};
use crate::CTX;
use twilight_model::{
application::{
command::{Command, CommandType},
Expand All @@ -17,13 +17,16 @@ pub fn command() -> Command {
}

pub async fn autocomplete(
_event: Box<InteractionCreate>,
_interaction: Box<InteractionCreate>,
_data: Box<CommandData>,
) -> anyhow::Result<()> {
Ok(())
}

pub async fn run(event: Box<InteractionCreate>, _data: Box<CommandData>) -> anyhow::Result<()> {
pub async fn run(
interaction: Box<InteractionCreate>,
_data: Box<CommandData>,
) -> anyhow::Result<()> {
let data = InteractionResponseData {
content: Some("Pong!".to_owned()),
flags: Some(MessageFlags::EPHEMERAL),
Expand All @@ -34,10 +37,8 @@ pub async fn run(event: Box<InteractionCreate>, _data: Box<CommandData>) -> anyh
kind: InteractionResponseType::ChannelMessageWithSource,
data: Some(data),
};
CONTEXT
.http
.interaction(APPLICATION_ID)
.create_response(event.id, &event.token, &response)
CTX.interaction()
.create_response(interaction.id, &interaction.token, &response)
.await?;

Ok(())
Expand Down
51 changes: 19 additions & 32 deletions src/command/restart.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{APPLICATION_ID, CONTEXT, ShardRestartKind};
use crate::{CTX, ShardRestartKind};
use std::iter;
use twilight_model::{
application::{
Expand Down Expand Up @@ -31,7 +31,7 @@ pub fn command(shards: u32) -> Command {
}

pub async fn autocomplete(
event: Box<InteractionCreate>,
interaction: Box<InteractionCreate>,
mut data: Box<CommandData>,
) -> anyhow::Result<()> {
let choice = |shard_id: u32| CommandOptionChoice {
Expand All @@ -45,15 +45,13 @@ pub async fn autocomplete(
unreachable!()
};

let choices: Vec<_> = match value.parse() {
let choices = match value.parse() {
Ok(shard_id) if shard_id == 0 => vec![choice(shard_id)],
Ok(shard_id) => starts_with(shard_id, CONTEXT.shard_handles.len() as u32)
Ok(shard_id) => starts_with(shard_id, CTX.shards.len() as u32)
.take(25)
.map(choice)
.collect(),
Err(_) => (0..25.min(CONTEXT.shard_handles.len() as u32))
.map(choice)
.collect(),
Err(_) => (0..25.min(CTX.shards.len() as u32)).map(choice).collect(),
};
let data = InteractionResponseData {
choices: Some(choices),
Expand All @@ -64,16 +62,17 @@ pub async fn autocomplete(
kind: InteractionResponseType::ApplicationCommandAutocompleteResult,
data: Some(data),
};
CONTEXT
.http
.interaction(APPLICATION_ID)
.create_response(event.id, &event.token, &response)
CTX.interaction()
.create_response(interaction.id, &interaction.token, &response)
.await?;

Ok(())
}

pub async fn run(event: Box<InteractionCreate>, mut data: Box<CommandData>) -> anyhow::Result<()> {
pub async fn run(
interaction: Box<InteractionCreate>,
mut data: Box<CommandData>,
) -> anyhow::Result<()> {
let mut options = data.options.drain(..);
let CommandOptionValue::Integer(shard_id) = options.next().unwrap().value else {
unreachable!()
Expand All @@ -86,12 +85,8 @@ pub async fn run(event: Box<InteractionCreate>, mut data: Box<CommandData>) -> a
_ => ShardRestartKind::Normal,
};

let shard_handle = CONTEXT
.shard_handles
.get(&(shard_id as u32))
.unwrap()
.clone();
let restart_result = shard_handle.restart(kind);
let shard = CTX.shards.get(&(shard_id as u32)).unwrap().clone();
let restart_result = shard.restart(kind);

let response = if restart_result.is_forced() {
tracing::debug!(shard.id = shard_id, "force restarting shard");
Expand All @@ -110,31 +105,23 @@ pub async fn run(event: Box<InteractionCreate>, mut data: Box<CommandData>) -> a
data: None,
}
};
CONTEXT
.http
.interaction(APPLICATION_ID)
.create_response(event.id, &event.token, &response)
CTX.interaction()
.create_response(interaction.id, &interaction.token, &response)
.await?;
if restart_result.is_forced() {
return Ok(());
}

shard_handle.restarted().await;
let is_restarted = CONTEXT
.shard_handles
.get(&(shard_id as u32))
.unwrap()
.is_valid();
shard.restarted().await;
let is_restarted = CTX.shards.get(&(shard_id as u32)).unwrap().is_valid();

let content = if is_restarted {
"Shard restarted"
} else {
"Bot shut down"
};
CONTEXT
.http
.interaction(APPLICATION_ID)
.update_response(&event.token)
CTX.interaction()
.update_response(&interaction.token)
.content(Some(content))
.await?;

Expand Down
25 changes: 19 additions & 6 deletions src/context.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,35 @@
use crate::ShardHandle;
use dashmap::DashMap;
use std::{ops::Deref, sync::OnceLock};
use twilight_http::Client;
use twilight_http::{Client, client::InteractionClient};
use twilight_model::id::{Id, marker::ApplicationMarker};

pub static CONTEXT: Ref = Ref(OnceLock::new());
pub static CTX: Ref = Ref(OnceLock::new());

#[derive(Debug)]
pub struct Context {
pub application_id: Id<ApplicationMarker>,
pub http: Client,
pub shard_handles: DashMap<u32, ShardHandle>,
pub shards: DashMap<u32, ShardHandle>,
}

pub fn initialize(http: Client, shard_handles: DashMap<u32, ShardHandle>) {
impl Context {
pub fn interaction(&self) -> InteractionClient<'_> {
self.http.interaction(self.application_id)
}
}

pub fn init(
application_id: Id<ApplicationMarker>,
http: Client,
shards: DashMap<u32, ShardHandle>,
) {
let context = Context {
application_id,
http,
shard_handles,
shards,
};
assert!(CONTEXT.0.set(context).is_ok());
assert!(CTX.0.set(context).is_ok());
}

pub struct Ref(OnceLock<Context>);
Expand Down
6 changes: 2 additions & 4 deletions src/dispatch.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{CONTEXT, ConfigBuilderExt as _, EVENT_TYPES, ResumeInfo};
use crate::{CTX, ConfigBuilderExt as _, EVENT_TYPES, ResumeInfo};
use std::{error::Error, pin::pin};
use tokio::{signal, sync::watch};
use tokio_util::task::TaskTracker;
Expand All @@ -23,9 +23,7 @@ pub struct ShardHandle(watch::Sender<Option<ShardRestartKind>>);
impl ShardHandle {
fn insert(shard_id: ShardId) -> watch::Receiver<Option<ShardRestartKind>> {
let (tx, rx) = watch::channel(None);
CONTEXT
.shard_handles
.insert(shard_id.number(), ShardHandle(tx));
CTX.shards.insert(shard_id.number(), ShardHandle(tx));

rx
}
Expand Down
36 changes: 17 additions & 19 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ mod context;
mod dispatch;
mod resume;

pub(crate) use self::{
context::CONTEXT,
pub use self::{
context::CTX,
dispatch::{ShardHandle, ShardRestartKind},
resume::{ConfigBuilderExt, Info as ResumeInfo},
};
Expand All @@ -16,15 +16,10 @@ use tokio::signal;
use tracing::{Instrument as _, instrument::Instrumented};
use twilight_gateway::{ConfigBuilder, Event, EventTypeFlags, Intents, queue::InMemoryQueue};
use twilight_http::Client;
use twilight_model::id::{
Id,
marker::{ApplicationMarker, GuildMarker},
};
use twilight_model::id::{Id, marker::GuildMarker};

#[rustfmt::skip]
const ADMIN_GUILD_ID: Id<GuildMarker> = Id::new({{admin_guild_id}});
#[rustfmt::skip]
const APPLICATION_ID: Id<ApplicationMarker> = Id::new({{application_id}});
const EVENT_TYPES: EventTypeFlags = EventTypeFlags::INTERACTION_CREATE;
const INTENTS: Intents = Intents::empty();

Expand All @@ -35,22 +30,25 @@ async fn main() -> anyhow::Result<()> {
let token = env::var("TOKEN").context("reading `TOKEN`")?;

let http = Client::new(token.clone());
let info = async { Ok::<_, anyhow::Error>(http.gateway().authed().await?.model().await?) }
let app = async { anyhow::Ok(http.current_user_application().await?.model().await?) }
.await
.context("getting app")?;
let info = async { anyhow::Ok(http.gateway().authed().await?.model().await?) }
.await
.context("getting info")?;
async {
http.interaction(APPLICATION_ID)
http.interaction(app.id)
.set_global_commands(&command::global_commands())
.await?;
http.interaction(APPLICATION_ID)
http.interaction(app.id)
.set_guild_commands(ADMIN_GUILD_ID, &command::admin_commands(info.shards))
.await?;
Ok::<_, anyhow::Error>(())
anyhow::Ok(())
}
.await
.context("putting commands")?;
let shard_handles = DashMap::new();
context::initialize(http, shard_handles);
let shards = DashMap::new();
context::init(app.id, http, shards);

// The queue defaults are static and may be incorrect for large or newly
// restarted bots.
Expand All @@ -73,11 +71,11 @@ async fn main() -> anyhow::Result<()> {
tracing::info!("shutting down; press CTRL-C to abort");

let join_all_tasks = async {
let mut resume_info = Vec::new();
let mut resume_info = Vec::with_capacity(tasks.len());
for task in tasks {
resume_info.push(task.await?);
}
Ok::<_, anyhow::Error>(resume_info)
anyhow::Ok(resume_info)
};
let resume_info = tokio::select! {
_ = signal::ctrl_c() => Vec::new(),
Expand All @@ -103,9 +101,9 @@ async fn event_handler(event: Event, _state: ()) {

#[allow(clippy::single_match)]
match event {
Event::InteractionCreate(event) => {
let span = tracing::info_span!("interaction", id = %event.id);
log_err(command::interaction(event).instrument(span)).await;
Event::InteractionCreate(interaction) => {
let span = tracing::info_span!("interaction", id = %interaction.id);
log_err(command::handler(interaction).instrument(span)).await;
}
_ => {}
}
Expand Down
2 changes: 1 addition & 1 deletion src/resume.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ pub async fn save(info: &[Info]) -> anyhow::Result<()> {
pub async fn restore(config: Config, shards: u32) -> Vec<Shard> {
let info = async {
let contents = fs::read(INFO_FILE).await?;
Ok::<_, anyhow::Error>(serde_json::from_slice::<Vec<Info>>(&contents)?)
anyhow::Ok(serde_json::from_slice::<Vec<Info>>(&contents)?)
}
.await;

Expand Down