|
| 1 | +// Copyright (c) Sony Pictures Imageworks, et al. |
| 2 | +// SPDX-License-Identifier: Apache-2.0 |
| 3 | +// https://github.com/imageworks/spk |
| 4 | + |
| 5 | +use proc_macro::TokenStream; |
| 6 | +use quote::{format_ident, quote}; |
| 7 | +use syn::LitStr; |
| 8 | + |
| 9 | +/// Derive macro for generating boilerplate `Default` and `Drop` impls |
| 10 | +/// for a struct with `indicatif::ProgressBar` fields. |
| 11 | +/// |
| 12 | +/// The struct is required to have a field named `renderer` and one or more |
| 13 | +/// fields of type `indicatif::ProgressBar`. Each progress bar field requires |
| 14 | +/// a `#[progress_bar]` attribute with a `message` argument. A `template` |
| 15 | +/// argument is also required either at the struct level or the field level. |
| 16 | +/// |
| 17 | +/// # Example |
| 18 | +/// |
| 19 | +/// ``` |
| 20 | +/// use progress_bar_derive_macro::ProgressBar; |
| 21 | +/// #[derive(ProgressBar)] |
| 22 | +/// struct MyStruct { |
| 23 | +/// renderer: Option<std::thread::JoinHandle<()>>, |
| 24 | +/// #[progress_bar( |
| 25 | +/// message = "processing widgets", |
| 26 | +/// template = " {spinner} {msg:<16.green} [{bar:40.cyan/dim}] {pos:>8}/{len:6}" |
| 27 | +/// )] |
| 28 | +/// widgets: indicatif::ProgressBar, |
| 29 | +/// } |
| 30 | +/// ``` |
| 31 | +#[proc_macro_derive(ProgressBar, attributes(progress_bar))] |
| 32 | +pub fn proc_macro_derive(input: TokenStream) -> TokenStream { |
| 33 | + let ast = syn::parse(input).unwrap(); |
| 34 | + impl_proc_macro_derive(&ast) |
| 35 | +} |
| 36 | + |
| 37 | +fn impl_proc_macro_derive(ast: &syn::DeriveInput) -> TokenStream { |
| 38 | + let name = &ast.ident; |
| 39 | + |
| 40 | + let mut progress_bar_field_names = Vec::new(); |
| 41 | + let mut bars = Vec::new(); |
| 42 | + |
| 43 | + if let syn::Data::Struct(s) = &ast.data { |
| 44 | + let mut template = None; |
| 45 | + |
| 46 | + for attr in &ast.attrs { |
| 47 | + if !attr.path().is_ident("progress_bar") { |
| 48 | + continue; |
| 49 | + } |
| 50 | + |
| 51 | + if let Err(err) = attr.parse_nested_meta(|meta| { |
| 52 | + if meta.path.is_ident("template") { |
| 53 | + let value = meta.value()?; |
| 54 | + let s: LitStr = value.parse()?; |
| 55 | + template = Some(s.value()); |
| 56 | + return Ok(()); |
| 57 | + } |
| 58 | + Ok(()) |
| 59 | + }) { |
| 60 | + return err.to_compile_error().into(); |
| 61 | + } |
| 62 | + } |
| 63 | + |
| 64 | + for field in &s.fields { |
| 65 | + let Some(ident) = &field.ident else { continue; }; |
| 66 | + if let syn::Type::Path(p) = &field.ty { |
| 67 | + if let Some(field_type) = p.path.segments.last().map(|s| &s.ident) { |
| 68 | + if field_type != "ProgressBar" { |
| 69 | + continue; |
| 70 | + } |
| 71 | + |
| 72 | + let mut message = None; |
| 73 | + |
| 74 | + for attr in &field.attrs { |
| 75 | + if !attr.path().is_ident("progress_bar") { |
| 76 | + continue; |
| 77 | + } |
| 78 | + |
| 79 | + if let Err(err) = attr.parse_nested_meta(|meta| { |
| 80 | + if meta.path.is_ident("message") { |
| 81 | + let value = meta.value()?; |
| 82 | + let s: LitStr = value.parse()?; |
| 83 | + message = Some(s.value()); |
| 84 | + return Ok(()); |
| 85 | + } |
| 86 | + if meta.path.is_ident("template") { |
| 87 | + let value = meta.value()?; |
| 88 | + let s: LitStr = value.parse()?; |
| 89 | + template = Some(s.value()); |
| 90 | + return Ok(()); |
| 91 | + } |
| 92 | + Ok(()) |
| 93 | + }) { |
| 94 | + return err.to_compile_error().into(); |
| 95 | + } |
| 96 | + } |
| 97 | + |
| 98 | + let Some(message) = message else { |
| 99 | + return syn::Error::new_spanned( |
| 100 | + field, |
| 101 | + "Missing #[progress_bar(message = \"...\")] attribute", |
| 102 | + ) |
| 103 | + .to_compile_error() |
| 104 | + .into(); |
| 105 | + }; |
| 106 | + |
| 107 | + let Some(template) = &template else { |
| 108 | + return syn::Error::new_spanned( |
| 109 | + field, |
| 110 | + "Missing #[progress_bar(template = \"...\")] attribute", |
| 111 | + ) |
| 112 | + .to_compile_error() |
| 113 | + .into(); |
| 114 | + }; |
| 115 | + |
| 116 | + let ident_style = format_ident!("{ident}_style"); |
| 117 | + |
| 118 | + bars.push(quote! { |
| 119 | + let #ident_style = indicatif::ProgressStyle::default_bar() |
| 120 | + .template(#template) |
| 121 | + .tick_strings(TICK_STRINGS) |
| 122 | + .progress_chars(PROGRESS_CHARS); |
| 123 | + let #ident = bars.add( |
| 124 | + indicatif::ProgressBar::new(0) |
| 125 | + .with_style(#ident_style) |
| 126 | + .with_message(#message), |
| 127 | + ); |
| 128 | + }); |
| 129 | + |
| 130 | + progress_bar_field_names.push(quote! { #ident }); |
| 131 | + } |
| 132 | + } |
| 133 | + } |
| 134 | + }; |
| 135 | + |
| 136 | + let gen = quote! { |
| 137 | + impl Default for #name { |
| 138 | + fn default() -> Self { |
| 139 | + static TICK_STRINGS: &[&str] = &["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]; |
| 140 | + static PROGRESS_CHARS: &str = "=>-"; |
| 141 | + |
| 142 | + let bars = indicatif::MultiProgress::new(); |
| 143 | + #(#bars)* |
| 144 | + #(#progress_bar_field_names.enable_steady_tick(100);)* |
| 145 | + // the progress bar must be awaited from some thread |
| 146 | + // or nothing will be shown in the terminal |
| 147 | + let renderer = Some(std::thread::spawn(move || { |
| 148 | + if let Err(err) = bars.join() { |
| 149 | + tracing::error!("Failed to render commit progress: {err}"); |
| 150 | + } |
| 151 | + })); |
| 152 | + Self { |
| 153 | + #(#progress_bar_field_names,)* |
| 154 | + renderer, |
| 155 | + } |
| 156 | + } |
| 157 | + } |
| 158 | + |
| 159 | + impl Drop for #name { |
| 160 | + fn drop(&mut self) { |
| 161 | + #(self.#progress_bar_field_names.finish_and_clear();)* |
| 162 | + if let Some(r) = self.renderer.take() { |
| 163 | + let _ = r.join(); |
| 164 | + } |
| 165 | + } |
| 166 | + } |
| 167 | + }; |
| 168 | + gen.into() |
| 169 | +} |
0 commit comments