diff --git a/prost-derive/src/field/map.rs b/prost-derive/src/field/map.rs index 4855cc5c6..da01d41ea 100644 --- a/prost-derive/src/field/map.rs +++ b/prost-derive/src/field/map.rs @@ -367,6 +367,7 @@ fn key_ty_from_str(s: &str) -> Result { | scalar::Ty::Sfixed32 | scalar::Ty::Sfixed64 | scalar::Ty::Bool + | scalar::Ty::SmallString | scalar::Ty::String => Ok(ty), _ => bail!("invalid map key type: {}", s), } diff --git a/prost-derive/src/field/scalar.rs b/prost-derive/src/field/scalar.rs index c2e870524..618f579f6 100644 --- a/prost-derive/src/field/scalar.rs +++ b/prost-derive/src/field/scalar.rs @@ -116,6 +116,13 @@ impl Field { let tag = self.tag; match self.kind { + Kind::Plain(DefaultValue::SmallString) => { + quote! { + if !#ident.is_empty(){ + #encode_fn(#tag, &#ident, buf); + } + } + } Kind::Plain(ref default) => { let default = default.typed(); quote! { @@ -170,6 +177,15 @@ impl Field { let tag = self.tag; match self.kind { + Kind::Plain(DefaultValue::SmallString) => { + quote! { + if !#ident.is_empty() { + #encoded_len_fn(#tag, &#ident) + } else { + 0 + } + } + } Kind::Plain(ref default) => { let default = default.typed(); quote! { @@ -194,7 +210,7 @@ impl Field { Kind::Plain(ref default) | Kind::Required(ref default) => { let default = default.typed(); match self.ty { - Ty::String | Ty::Bytes(..) => quote!(#ident.clear()), + Ty::String | Ty::SmallString | Ty::Bytes(..) => quote!(#ident.clear()), _ => quote!(#ident = #default), } } @@ -398,6 +414,7 @@ pub enum Ty { Sfixed64, Bool, String, + SmallString, Bytes(BytesTy), Enumeration(Path), } @@ -442,6 +459,7 @@ impl Ty { Meta::Path(ref name) if name.is_ident("sfixed64") => Ty::Sfixed64, Meta::Path(ref name) if name.is_ident("bool") => Ty::Bool, Meta::Path(ref name) if name.is_ident("string") => Ty::String, + Meta::Path(ref name) if name.is_ident("smallstring") => Ty::SmallString, Meta::Path(ref name) if name.is_ident("bytes") => Ty::Bytes(BytesTy::Vec), Meta::NameValue(MetaNameValue { ref path, @@ -487,6 +505,7 @@ impl Ty { "sfixed64" => Ty::Sfixed64, "bool" => Ty::Bool, "string" => Ty::String, + "smallstring" => Ty::SmallString, "bytes" => Ty::Bytes(BytesTy::Vec), s if s.len() > enumeration_len && &s[..enumeration_len] == "enumeration" => { let s = &s[enumeration_len..].trim(); @@ -523,6 +542,7 @@ impl Ty { Ty::Sfixed64 => "sfixed64", Ty::Bool => "bool", Ty::String => "string", + Ty::SmallString => "smallstring", Ty::Bytes(..) => "bytes", Ty::Enumeration(..) => "enum", } @@ -532,6 +552,7 @@ impl Ty { pub fn rust_type(&self) -> TokenStream { match self { Ty::String => quote!(::prost::alloc::string::String), + Ty::SmallString => quote!(::compact_str::CompactString), Ty::Bytes(ty) => ty.rust_type(), _ => self.rust_ref_type(), } @@ -554,6 +575,7 @@ impl Ty { Ty::Sfixed64 => quote!(i64), Ty::Bool => quote!(bool), Ty::String => quote!(&str), + Ty::SmallString => quote!(&str), Ty::Bytes(..) => quote!(&[u8]), Ty::Enumeration(..) => quote!(i32), } @@ -568,7 +590,7 @@ impl Ty { /// Returns false if the scalar type is length delimited (i.e., `string` or `bytes`). pub fn is_numeric(&self) -> bool { - !matches!(self, Ty::String | Ty::Bytes(..)) + !matches!(self, Ty::String | Ty::SmallString | Ty::Bytes(..)) } } @@ -610,6 +632,7 @@ pub enum DefaultValue { U64(u64), Bool(bool), String(String), + SmallString, Bytes(Vec), Enumeration(TokenStream), Path(Path), @@ -774,6 +797,7 @@ impl DefaultValue { Ty::Bool => DefaultValue::Bool(false), Ty::String => DefaultValue::String(String::new()), + Ty::SmallString => DefaultValue::SmallString, Ty::Bytes(..) => DefaultValue::Bytes(Vec::new()), Ty::Enumeration(ref path) => DefaultValue::Enumeration(quote!(#path::default())), } @@ -785,6 +809,9 @@ impl DefaultValue { quote!(::prost::alloc::string::String::new()) } DefaultValue::String(ref value) => quote!(#value.into()), + DefaultValue::SmallString => { + quote!(::compact_str::CompactString::default()) + } DefaultValue::Bytes(ref value) if value.is_empty() => { quote!(::core::default::Default::default()) } @@ -800,6 +827,8 @@ impl DefaultValue { pub fn typed(&self) -> TokenStream { if let DefaultValue::Enumeration(_) = *self { quote!(#self as i32) + } else if let DefaultValue::SmallString = *self { + quote!(Default::default()) } else { quote!(#self) } @@ -817,6 +846,7 @@ impl ToTokens for DefaultValue { DefaultValue::U64(value) => value.to_tokens(tokens), DefaultValue::Bool(value) => value.to_tokens(tokens), DefaultValue::String(ref value) => value.to_tokens(tokens), + DefaultValue::SmallString => "".to_tokens(tokens), DefaultValue::Bytes(ref value) => { let byte_str = LitByteStr::new(value, Span::call_site()); tokens.append_all(quote!(#byte_str as &[u8])); diff --git a/prost/Cargo.toml b/prost/Cargo.toml index cc4e52689..28310c2ba 100644 --- a/prost/Cargo.toml +++ b/prost/Cargo.toml @@ -21,10 +21,12 @@ derive = ["dep:prost-derive"] prost-derive = ["derive"] # deprecated, please use derive feature instead no-recursion-limit = [] std = [] +smallstring = ["compact_str"] [dependencies] bytes = { version = "1", default-features = false } prost-derive = { version = "0.12.6", path = "../prost-derive", optional = true } +compact_str = { version = "0.8.0-beta", optional = true } [dev-dependencies] criterion = { version = "0.5", default-features = false } diff --git a/prost/src/encoding.rs b/prost/src/encoding.rs index 88f65e643..305abf536 100644 --- a/prost/src/encoding.rs +++ b/prost/src/encoding.rs @@ -832,6 +832,76 @@ pub mod string { } } +#[cfg(feature = "smallstring")] +pub mod smallstring { + use super::*; + use core::mem::ManuallyDrop; + + pub fn encode(tag: u32, value: &str, buf: &mut B) + where + B: BufMut, + { + encode_key(tag, WireType::LengthDelimited, buf); + encode_varint(value.len() as u64, buf); + buf.put_slice(value.as_bytes()); + } + pub fn merge( + wire_type: WireType, + value: &mut compact_str::CompactString, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + B: Buf, + { + unsafe { + let mut fake_vec = ManuallyDrop::new(Vec::from_raw_parts( + value.as_mut_ptr(), + value.len(), + value.capacity(), + )); + + bytes::merge_one_copy(wire_type, &mut *fake_vec, buf, ctx)?; + match compact_str::CompactString::from_utf8(&*fake_vec) { + Ok(s) => { + *value = s; + Ok(()) + } + Err(_) => Err(DecodeError::new( + "invalid string value: data is not UTF-8 encoded", + )), + } + } + } + + length_delimited!(compact_str::CompactString); + + #[cfg(test)] + mod test { + use compact_str::{CompactString, ToCompactString}; + use proptest::prelude::*; + + use super::super::test::{check_collection_type, check_type}; + use super::*; + + proptest! { + #[test] + fn check(value: String, tag in MIN_TAG..=MAX_TAG) { + let value = value.to_compact_string(); + super::test::check_type(value, tag, WireType::LengthDelimited, + encode, merge, |tag, s| encoded_len(tag, &s.to_compact_string()))?; + } + #[test] + fn check_repeated(value: Vec, tag in MIN_TAG..=MAX_TAG) { + let value = value.into_iter().map(CompactString::from).collect(); + super::test::check_collection_type(value, tag, WireType::LengthDelimited, + encode_repeated, merge_repeated, + encoded_len_repeated)?; + } + } + } +} + pub trait BytesAdapter: sealed::BytesAdapter {} mod sealed {