Skip to content

Commit 8a0dfcd

Browse files
committed
Add smallstring support with compact_str
1 parent baddf98 commit 8a0dfcd

File tree

4 files changed

+105
-2
lines changed

4 files changed

+105
-2
lines changed

prost-derive/src/field/map.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,7 @@ fn key_ty_from_str(s: &str) -> Result<scalar::Ty, Error> {
367367
| scalar::Ty::Sfixed32
368368
| scalar::Ty::Sfixed64
369369
| scalar::Ty::Bool
370+
| scalar::Ty::SmallString
370371
| scalar::Ty::String => Ok(ty),
371372
_ => bail!("invalid map key type: {}", s),
372373
}

prost-derive/src/field/scalar.rs

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,13 @@ impl Field {
115115
let tag = self.tag;
116116

117117
match self.kind {
118+
Kind::Plain(DefaultValue::SmallString) => {
119+
quote! {
120+
if !#ident.is_empty(){
121+
#encode_fn(#tag, &#ident, buf);
122+
}
123+
}
124+
}
118125
Kind::Plain(ref default) => {
119126
let default = default.typed();
120127
quote! {
@@ -169,6 +176,15 @@ impl Field {
169176
let tag = self.tag;
170177

171178
match self.kind {
179+
Kind::Plain(DefaultValue::SmallString) => {
180+
quote! {
181+
if !#ident.is_empty() {
182+
#encoded_len_fn(#tag, &#ident)
183+
} else {
184+
0
185+
}
186+
}
187+
}
172188
Kind::Plain(ref default) => {
173189
let default = default.typed();
174190
quote! {
@@ -193,7 +209,7 @@ impl Field {
193209
Kind::Plain(ref default) | Kind::Required(ref default) => {
194210
let default = default.typed();
195211
match self.ty {
196-
Ty::String | Ty::Bytes(..) => quote!(#ident.clear()),
212+
Ty::String | Ty::SmallString | Ty::Bytes(..) => quote!(#ident.clear()),
197213
_ => quote!(#ident = #default),
198214
}
199215
}
@@ -397,6 +413,7 @@ pub enum Ty {
397413
Sfixed64,
398414
Bool,
399415
String,
416+
SmallString,
400417
Bytes(BytesTy),
401418
Enumeration(Path),
402419
}
@@ -441,6 +458,7 @@ impl Ty {
441458
Meta::Path(ref name) if name.is_ident("sfixed64") => Ty::Sfixed64,
442459
Meta::Path(ref name) if name.is_ident("bool") => Ty::Bool,
443460
Meta::Path(ref name) if name.is_ident("string") => Ty::String,
461+
Meta::Path(ref name) if name.is_ident("smallstring") => Ty::SmallString,
444462
Meta::Path(ref name) if name.is_ident("bytes") => Ty::Bytes(BytesTy::Vec),
445463
Meta::NameValue(MetaNameValue {
446464
ref path,
@@ -486,6 +504,7 @@ impl Ty {
486504
"sfixed64" => Ty::Sfixed64,
487505
"bool" => Ty::Bool,
488506
"string" => Ty::String,
507+
"smallstring" => Ty::SmallString,
489508
"bytes" => Ty::Bytes(BytesTy::Vec),
490509
s if s.len() > enumeration_len && &s[..enumeration_len] == "enumeration" => {
491510
let s = &s[enumeration_len..].trim();
@@ -522,6 +541,7 @@ impl Ty {
522541
Ty::Sfixed64 => "sfixed64",
523542
Ty::Bool => "bool",
524543
Ty::String => "string",
544+
Ty::SmallString => "smallstring",
525545
Ty::Bytes(..) => "bytes",
526546
Ty::Enumeration(..) => "enum",
527547
}
@@ -531,6 +551,7 @@ impl Ty {
531551
pub fn rust_type(&self) -> TokenStream {
532552
match self {
533553
Ty::String => quote!(::prost::alloc::string::String),
554+
Ty::SmallString => quote!(::compact_str::CompactString),
534555
Ty::Bytes(ty) => ty.rust_type(),
535556
_ => self.rust_ref_type(),
536557
}
@@ -553,6 +574,7 @@ impl Ty {
553574
Ty::Sfixed64 => quote!(i64),
554575
Ty::Bool => quote!(bool),
555576
Ty::String => quote!(&str),
577+
Ty::SmallString => quote!(&str),
556578
Ty::Bytes(..) => quote!(&[u8]),
557579
Ty::Enumeration(..) => quote!(i32),
558580
}
@@ -567,7 +589,7 @@ impl Ty {
567589

568590
/// Returns false if the scalar type is length delimited (i.e., `string` or `bytes`).
569591
pub fn is_numeric(&self) -> bool {
570-
!matches!(self, Ty::String | Ty::Bytes(..))
592+
!matches!(self, Ty::String | Ty::SmallString | Ty::Bytes(..))
571593
}
572594
}
573595

@@ -609,6 +631,7 @@ pub enum DefaultValue {
609631
U64(u64),
610632
Bool(bool),
611633
String(String),
634+
SmallString,
612635
Bytes(Vec<u8>),
613636
Enumeration(TokenStream),
614637
Path(Path),
@@ -773,6 +796,7 @@ impl DefaultValue {
773796

774797
Ty::Bool => DefaultValue::Bool(false),
775798
Ty::String => DefaultValue::String(String::new()),
799+
Ty::SmallString => DefaultValue::SmallString,
776800
Ty::Bytes(..) => DefaultValue::Bytes(Vec::new()),
777801
Ty::Enumeration(ref path) => DefaultValue::Enumeration(quote!(#path::default())),
778802
}
@@ -784,6 +808,9 @@ impl DefaultValue {
784808
quote!(::prost::alloc::string::String::new())
785809
}
786810
DefaultValue::String(ref value) => quote!(#value.into()),
811+
DefaultValue::SmallString => {
812+
quote!(::compact_str::CompactString::default())
813+
}
787814
DefaultValue::Bytes(ref value) if value.is_empty() => {
788815
quote!(::core::default::Default::default())
789816
}
@@ -799,6 +826,8 @@ impl DefaultValue {
799826
pub fn typed(&self) -> TokenStream {
800827
if let DefaultValue::Enumeration(_) = *self {
801828
quote!(#self as i32)
829+
} else if let DefaultValue::SmallString = *self {
830+
quote!(Default::default())
802831
} else {
803832
quote!(#self)
804833
}
@@ -816,6 +845,7 @@ impl ToTokens for DefaultValue {
816845
DefaultValue::U64(value) => value.to_tokens(tokens),
817846
DefaultValue::Bool(value) => value.to_tokens(tokens),
818847
DefaultValue::String(ref value) => value.to_tokens(tokens),
848+
DefaultValue::SmallString => "".to_tokens(tokens),
819849
DefaultValue::Bytes(ref value) => {
820850
let byte_str = LitByteStr::new(value, Span::call_site());
821851
tokens.append_all(quote!(#byte_str as &[u8]));

prost/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,12 @@ derive = ["dep:prost-derive"]
2727
prost-derive = ["derive"] # deprecated, please use derive feature instead
2828
no-recursion-limit = []
2929
std = []
30+
smallstring = ["compact_str"]
3031

3132
[dependencies]
3233
bytes = { version = "1", default-features = false }
3334
prost-derive = { version = "0.12.4", path = "../prost-derive", optional = true }
35+
compact_str = { version = "0.7.1", optional = true }
3436

3537
[dev-dependencies]
3638
criterion = { version = "0.4", default-features = false }

prost/src/encoding.rs

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -873,6 +873,76 @@ pub mod string {
873873
}
874874
}
875875

876+
#[cfg(feature = "smallstring")]
877+
pub mod smallstring {
878+
use super::*;
879+
use core::mem::ManuallyDrop;
880+
881+
pub fn encode<B>(tag: u32, value: &str, buf: &mut B)
882+
where
883+
B: BufMut,
884+
{
885+
encode_key(tag, WireType::LengthDelimited, buf);
886+
encode_varint(value.len() as u64, buf);
887+
buf.put_slice(value.as_bytes());
888+
}
889+
pub fn merge<B>(
890+
wire_type: WireType,
891+
value: &mut compact_str::CompactString,
892+
buf: &mut B,
893+
ctx: DecodeContext,
894+
) -> Result<(), DecodeError>
895+
where
896+
B: Buf,
897+
{
898+
unsafe {
899+
let mut fake_vec = ManuallyDrop::new(Vec::from_raw_parts(
900+
value.as_mut_ptr(),
901+
value.len(),
902+
value.capacity(),
903+
));
904+
905+
bytes::merge_one_copy(wire_type, &mut *fake_vec, buf, ctx)?;
906+
match compact_str::CompactString::from_utf8(&*fake_vec) {
907+
Ok(s) => {
908+
*value = s;
909+
Ok(())
910+
}
911+
Err(_) => Err(DecodeError::new(
912+
"invalid string value: data is not UTF-8 encoded",
913+
)),
914+
}
915+
}
916+
}
917+
918+
length_delimited!(compact_str::CompactString);
919+
920+
#[cfg(test)]
921+
mod test {
922+
use compact_str::{CompactString, ToCompactString};
923+
use proptest::prelude::*;
924+
925+
use super::super::test::{check_collection_type, check_type};
926+
use super::*;
927+
928+
proptest! {
929+
#[test]
930+
fn check(value: String, tag in MIN_TAG..=MAX_TAG) {
931+
let value = value.to_compact_string();
932+
super::test::check_type(value, tag, WireType::LengthDelimited,
933+
encode, merge, |tag, s| encoded_len(tag, &s.to_compact_string()))?;
934+
}
935+
#[test]
936+
fn check_repeated(value: Vec<String>, tag in MIN_TAG..=MAX_TAG) {
937+
let value = value.into_iter().map(CompactString::from).collect();
938+
super::test::check_collection_type(value, tag, WireType::LengthDelimited,
939+
encode_repeated, merge_repeated,
940+
encoded_len_repeated)?;
941+
}
942+
}
943+
}
944+
}
945+
876946
pub trait BytesAdapter: sealed::BytesAdapter {}
877947

878948
mod sealed {

0 commit comments

Comments
 (0)