Skip to content

Commit

Permalink
Refactoring + make DB Enums more concise with proc macros
Browse files Browse the repository at this point in the history
  • Loading branch information
1Git2Clone committed Jan 3, 2025
1 parent 7ec7623 commit 9cf15f4
Show file tree
Hide file tree
Showing 11 changed files with 232 additions and 144 deletions.
91 changes: 32 additions & 59 deletions serenity_discord_bot_derive/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
extern crate proc_macro;
mod utils;

use proc_macro::TokenStream;
use quote::quote;
use utils::{
get_variant_str_values_by_name, impl_display, impl_display_with_vals,
string_manipulation::pascal_to_snake_case,
};

// Credit:
// https://stackoverflow.com/questions/68025264/how-to-get-all-the-variants-of-an-enum-in-a-vect-with-a-proc-macro/69812881#69812881
Expand Down Expand Up @@ -29,60 +34,6 @@ pub fn derive_all_variants(input: TokenStream) -> TokenStream {
.into()
}

fn impl_display<'a>(
enum_name: syn::Ident,
variant_idents: Vec<&'a syn::Ident>,
variants_values: Vec<String>,
display_pat: fn(ident: &'a syn::Ident, val: &str) -> String,
) -> proc_macro2::TokenStream {
let display_iter = variant_idents
.iter()
.zip(variants_values.iter())
.map(|(i, v)| display_pat(i, v));

quote! {
impl std::fmt::Display for #enum_name {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
#(Self::#variant_idents => {
write!(
f,
#display_iter
)
})*
}
}
}
}
}

fn get_variant_str_values_by_name(enum_item: syn::DataEnum, name: &str) -> Vec<String> {
enum_item
.variants
.iter()
.filter_map(|v| {
if !v.attrs.iter().any(|attr| attr.path().is_ident(name)) {
return None;
}

v.attrs
.iter()
.find(|attr| attr.path().is_ident(name))
.map(|attr| match &attr.meta {
syn::Meta::NameValue(nv) => match &nv.value {
syn::Expr::Lit(lit_expr) => match &lit_expr.lit {
syn::Lit::Str(str) => Some(str.value()),
_ => None,
},
_ => None,
},
_ => None,
})
})
.map(|x| x.unwrap())
.collect()
}

#[proc_macro_derive(DiscordEmoji, attributes(emoji_id))]
pub fn derive_discord_emoji(input: TokenStream) -> TokenStream {
let ast = syn::parse_macro_input!(input as syn::DeriveInput);
Expand All @@ -99,7 +50,7 @@ pub fn derive_discord_emoji(input: TokenStream) -> TokenStream {
.collect::<Vec<_>>();
let variants_ids = get_variant_str_values_by_name(enum_item.clone(), "emoji_id");

let display_quote = impl_display(
let display_quote = impl_display_with_vals(
enum_name.clone(),
variant_idents.clone(),
variants_ids.clone(),
Expand Down Expand Up @@ -141,9 +92,6 @@ pub fn derive_discord_emoji(input: TokenStream) -> TokenStream {
.into()
}

// TODO: Simplify the process of making these single attribute derive macros due to the current
// code duplication

#[proc_macro_derive(Asset, attributes(src_path))]
pub fn derive_asset(input: TokenStream) -> TokenStream {
let ast = syn::parse_macro_input!(input as syn::DeriveInput);
Expand All @@ -160,10 +108,35 @@ pub fn derive_asset(input: TokenStream) -> TokenStream {
.collect::<Vec<_>>();
let variants_values = get_variant_str_values_by_name(enum_item.clone(), "src_path");

impl_display(enum_name, variant_idents, variants_values, |_ident, src_path| {
impl_display_with_vals(enum_name, variant_idents, variants_values, |_ident, src_path| {
format!(
"https://raw.githubusercontent.com/1Git2Clone/serenity-discord-bot/main/src/assets/{src_path}",
)
})
.into()
}

/// Implements `std::fmt::Display` for the enum by converting all the `PascalCase` variants to
/// `snake_case`.
///
/// NOTE: Also adds a `.as_str()` method.
#[proc_macro_derive(DatabaseEnum)]
pub fn derive_database_enum(input: TokenStream) -> TokenStream {
let ast = syn::parse_macro_input!(input as syn::DeriveInput);

let syn::Data::Enum(enum_item) = ast.data else {
return quote!(compile_error!("Only works on enums")).into();
};

let enum_name = ast.ident;
let variant_idents = enum_item
.variants
.iter()
.map(|v| &v.ident)
.collect::<Vec<_>>();

impl_display(enum_name, variant_idents, |ident| {
pascal_to_snake_case(&ident.to_string())
})
.into()
}
101 changes: 101 additions & 0 deletions serenity_discord_bot_derive/src/utils/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
pub mod string_manipulation;

use quote::quote;

pub fn quote_display_impl(
enum_name: syn::Ident,
variant_idents: Vec<&syn::Ident>,
display_results: &[String],
) -> proc_macro2::TokenStream {
let display = {
let iter = display_results.iter();
quote! {
impl std::fmt::Display for #enum_name {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
#(Self::#variant_idents => {
write!(
f,
#iter
)
})*
}
}
}
}
};
let as_str = {
let iter = display_results.iter();
quote! {
impl #enum_name {
pub fn as_str(&self) -> &'static str {
match self {
#(Self::#variant_idents => {
#iter
})*
}
}
}
}
};

quote! {
#display
#as_str
}
}

pub fn impl_display<'a>(
enum_name: syn::Ident,
variant_idents: Vec<&'a syn::Ident>,
display_pat: fn(ident: &'a syn::Ident) -> String,
) -> proc_macro2::TokenStream {
let res = variant_idents
.iter()
.map(|i| display_pat(i))
.collect::<Vec<_>>();

quote_display_impl(enum_name, variant_idents, &res)
}

pub fn impl_display_with_vals<'a>(
enum_name: syn::Ident,
variant_idents: Vec<&'a syn::Ident>,
variants_values: Vec<String>,
display_pat: fn(ident: &'a syn::Ident, val: &str) -> String,
) -> proc_macro2::TokenStream {
let res = variant_idents
.iter()
.zip(variants_values.iter())
.map(|(i, v)| display_pat(i, v))
.collect::<Vec<_>>();

quote_display_impl(enum_name, variant_idents, &res)
}

pub fn get_variant_str_values_by_name(enum_item: syn::DataEnum, name: &str) -> Vec<String> {
enum_item
.variants
.iter()
.filter_map(|v| {
if !v.attrs.iter().any(|attr| attr.path().is_ident(name)) {
return None;
}

v.attrs
.iter()
.find(|attr| attr.path().is_ident(name))
.map(|attr| match &attr.meta {
syn::Meta::NameValue(nv) => match &nv.value {
syn::Expr::Lit(lit_expr) => match &lit_expr.lit {
syn::Lit::Str(str) => Some(str.value()),
_ => None,
},
_ => None,
},
_ => None,
})
})
.map(|x| x.unwrap())
.collect()
}
15 changes: 15 additions & 0 deletions serenity_discord_bot_derive/src/utils/string_manipulation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
pub fn pascal_to_snake_case(s: &str) -> String {
let mut res = String::with_capacity(s.len());

let mut chars = s.chars().peekable();
while let Some(c) = chars.next() {
res.push_str(&c.to_lowercase().to_string());
if let Some(next) = chars.peek() {
if next.is_uppercase() {
res.push('_');
}
}
}

res
}
4 changes: 2 additions & 2 deletions src/commands/level_cmds/level.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ pub async fn level(
};
let level = level_xp_and_rank_row
.1
.get::<i32, &str>(LEVELS_TABLE[&LevelsSchema::Level]);
.get::<i32, &str>(LevelsSchema::Level.as_str());
let xp = level_xp_and_rank_row
.1
.get::<i32, &str>(LEVELS_TABLE[&LevelsSchema::ExperiencePoints]);
.get::<i32, &str>(LevelsSchema::ExperiencePoints.as_str());

let avatar = target_replied_user.face().replace(".webp", ".png");
let username = &target_replied_user.name;
Expand Down
6 changes: 3 additions & 3 deletions src/commands/level_cmds/toplevels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub async fn toplevels(ctx: Context<'_>) -> Result<(), Error> {
ctx.defer().await?;
let user_ids: Vec<u64> = level_and_xp_rows
.par_iter()
.map(|row| row.get::<i64, &str>(LEVELS_TABLE[&LevelsSchema::UserId]) as u64)
.map(|row| row.get::<i64, &str>(LevelsSchema::UserId.as_str()) as u64)
.collect();
let users = try_join_all(
user_ids
Expand All @@ -46,8 +46,8 @@ pub async fn toplevels(ctx: Context<'_>) -> Result<(), Error> {
.enumerate()
{
let (level, xp) = (
row.get::<u32, &str>(LEVELS_TABLE[&LevelsSchema::Level]),
row.get::<u32, &str>(LEVELS_TABLE[&LevelsSchema::ExperiencePoints]),
row.get::<u32, &str>(LevelsSchema::Level.as_str()),
row.get::<u32, &str>(LevelsSchema::ExperiencePoints.as_str()),
);
let xp_to_level_up = calculate_xp_to_level_up(level);

Expand Down
16 changes: 9 additions & 7 deletions src/commands/level_logic.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use crate::prelude::*;
pub struct LevelStats {
pub updated_level: u32,
pub updated_experience: u32,
}

pub fn calculate_xp_to_level_up(level: u32) -> u32 {
level * 100
}

/// Set the leveling condition and return the updated level with reset xp if true.
pub async fn update_level(experience: u32, level: u32) -> HashMap<LevelsSchema, u32> {
use crate::enums::schemas::LevelsSchema as DbSch;
pub async fn update_level(experience: u32, level: u32) -> LevelStats {
let update_level = if experience >= calculate_xp_to_level_up(level) {
level + 1
} else {
Expand All @@ -15,8 +17,8 @@ pub async fn update_level(experience: u32, level: u32) -> HashMap<LevelsSchema,

let update_experience = if update_level == level { experience } else { 0 };

HashMap::from([
(DbSch::ExperiencePoints, update_experience),
(DbSch::Level, update_level),
])
LevelStats {
updated_level: update_level,
updated_experience: update_experience,
}
}
Loading

0 comments on commit 9cf15f4

Please sign in to comment.