From a86239b60a2acdf3b819efad3c0084f52616ca59 Mon Sep 17 00:00:00 2001 From: Mikhail Zabaluev Date: Tue, 26 Mar 2024 18:44:55 +0200 Subject: [PATCH 1/8] prost: add OpenEnum A wrapper type to represent fields of enumerated types, with the possibility of unknown values. --- prost/src/lib.rs | 2 + prost/src/open_enum.rs | 150 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 152 insertions(+) create mode 100644 prost/src/open_enum.rs diff --git a/prost/src/lib.rs b/prost/src/lib.rs index efdfbc5c1..e3d33c501 100644 --- a/prost/src/lib.rs +++ b/prost/src/lib.rs @@ -12,6 +12,7 @@ pub use bytes; mod error; mod message; mod name; +mod open_enum; mod types; #[doc(hidden)] @@ -23,6 +24,7 @@ pub use crate::encoding::length_delimiter::{ pub use crate::error::{DecodeError, EncodeError, UnknownEnumValue}; pub use crate::message::Message; pub use crate::name::Name; +pub use crate::open_enum::OpenEnum; // See `encoding::DecodeContext` for more info. // 100 is the default recursion limit in the C++ implementation. 
diff --git a/prost/src/open_enum.rs b/prost/src/open_enum.rs new file mode 100644 index 000000000..c36a1732e --- /dev/null +++ b/prost/src/open_enum.rs @@ -0,0 +1,150 @@ +use crate::encoding::{DecodeContext, WireType}; +use crate::{DecodeError, Message}; + +use bytes::{Buf, BufMut}; + +use core::fmt::Debug; + +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub enum OpenEnum { + Known(T), + Unknown(i32), +} + +impl Default for OpenEnum +where + T: Default, +{ + fn default() -> Self { + Self::Known(T::default()) + } +} + +impl From for OpenEnum { + fn from(value: T) -> Self { + Self::Known(value) + } +} + +impl OpenEnum { + pub fn from_raw(value: i32) -> Self + where + i32: TryInto, + { + match value.try_into() { + Ok(v) => Self::Known(v), + Err(_) => Self::Unknown(value), + } + } + + pub fn into_raw(self) -> i32 + where + T: Into, + { + match self { + Self::Known(v) => v.into(), + Self::Unknown(v) => v, + } + } + + pub fn to_raw(&self) -> i32 + where + T: Clone + Into, + { + match self { + Self::Known(v) => v.clone().into(), + Self::Unknown(v) => *v, + } + } +} + +impl OpenEnum { + pub fn unwrap(self) -> T { + match self { + Self::Known(v) => v, + Self::Unknown(v) => panic!("unknown field value {}", v), + } + } + + pub fn unwrap_or(self, default: T) -> T { + match self { + Self::Known(v) => v, + Self::Unknown(_) => default, + } + } + + pub fn unwrap_or_else(self, f: F) -> T + where + F: FnOnce(i32) -> T, + { + match self { + Self::Known(v) => v, + Self::Unknown(v) => f(v), + } + } + + pub fn unwrap_or_default(self) -> T + where + T: Default, + { + match self { + Self::Known(v) => v, + Self::Unknown(_) => T::default(), + } + } + + pub fn known(self) -> Option { + match self { + Self::Known(v) => Some(v), + Self::Unknown(_) => None, + } + } + + pub fn known_or(self, err: E) -> Result { + match self { + Self::Known(v) => Ok(v), + Self::Unknown(_) => Err(err), + } + } + + pub fn known_or_else(self, err: F) -> Result + where + F: FnOnce(i32) -> E, + { + 
match self { + Self::Known(v) => Ok(v), + Self::Unknown(v) => Err(err(v)), + } + } +} + +impl Message for OpenEnum +where + T: Clone + Into + Debug + Send + Sync, + i32: TryInto, +{ + fn encoded_len(&self) -> usize { + self.to_raw().encoded_len() + } + + fn encode_raw(&self, buf: &mut impl BufMut) { + self.to_raw().encode_raw(buf) + } + + fn merge_field( + &mut self, + tag: u32, + wire_type: WireType, + buf: &mut impl Buf, + ctx: DecodeContext, + ) -> Result<(), DecodeError> { + let mut raw = 0; + ::merge_field(&mut raw, tag, wire_type, buf, ctx)?; + *self = OpenEnum::from_raw(raw); + Ok(()) + } + + fn clear(&mut self) { + *self = OpenEnum::from_raw(0); + } +} From 72cd750b0388a19f1d2c9e3056a808ad31ad0568 Mon Sep 17 00:00:00 2001 From: Mikhail Zabaluev Date: Fri, 29 Mar 2024 14:42:53 +0200 Subject: [PATCH 2/8] prost: encoding module for OpenEnum --- prost/src/encoding.rs | 122 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 122 insertions(+) diff --git a/prost/src/encoding.rs b/prost/src/encoding.rs index e12455574..de7574aa5 100644 --- a/prost/src/encoding.rs +++ b/prost/src/encoding.rs @@ -519,6 +519,128 @@ fixed_width!( get_i64_le ); +pub mod enumeration { + use super::*; + use crate::OpenEnum; + + pub fn encode(tag: u32, value: &OpenEnum, buf: &mut B) + where + T: Clone + Into, + B: BufMut, + { + int32::encode(tag, &value.to_raw(), buf) + } + + pub fn merge( + wire_type: WireType, + value: &mut OpenEnum, + buf: &mut B, + _ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + i32: TryInto, + B: Buf, + { + let mut raw = 0; + int32::merge(wire_type, &mut raw, buf, DecodeContext::default())?; + *value = OpenEnum::from_raw(raw); + Ok(()) + } + + pub fn encode_repeated(tag: u32, values: &[OpenEnum], buf: &mut B) + where + T: Clone + Into, + B: BufMut, + { + for value in values { + encode(tag, value, buf); + } + } + + pub fn encode_packed(tag: u32, values: &[OpenEnum], buf: &mut B) + where + T: Clone + Into, + B: BufMut, + { + if values.is_empty() { 
+ return; + } + + encode_key(tag, WireType::LengthDelimited, buf); + let len: usize = values + .iter() + .map(|value| encoded_len_varint(value.to_raw() as u64)) + .sum(); + encode_varint(len as u64, buf); + + for value in values { + encode_varint(value.to_raw() as u64, buf); + } + } + + pub fn merge_repeated( + wire_type: WireType, + values: &mut Vec>, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + T: Default, + i32: TryInto, + B: Buf, + { + if wire_type == WireType::LengthDelimited { + // Packed. + merge_loop(values, buf, ctx, |values, buf, ctx| { + let mut value = Default::default(); + merge(WireType::Varint, &mut value, buf, ctx)?; + values.push(value); + Ok(()) + }) + } else { + // Unpacked. + check_wire_type(WireType::Varint, wire_type)?; + let mut value = Default::default(); + merge(wire_type, &mut value, buf, ctx)?; + values.push(value); + Ok(()) + } + } + + pub fn encoded_len(tag: u32, value: &OpenEnum) -> usize + where + T: Clone + Into, + { + int32::encoded_len(tag, &value.to_raw()) + } + + pub fn encoded_len_repeated(tag: u32, values: &[OpenEnum]) -> usize + where + T: Clone + Into, + { + key_len(tag) * values.len() + + values + .iter() + .map(|value| encoded_len_varint(value.to_raw() as u64)) + .sum::() + } + + pub fn encoded_len_packed(tag: u32, values: &[OpenEnum]) -> usize + where + T: Clone + Into, + { + if values.is_empty() { + 0 + } else { + let len = values + .iter() + .map(|value| encoded_len_varint(value.to_raw() as u64)) + .sum::(); + key_len(tag) + encoded_len_varint(len as u64) + len + } + } +} + /// Macro which emits encoding functions for a length-delimited type. macro_rules! length_delimited { ($ty:ty) => { From 29da76b66c15bfa893cf62c6dbcfccf08298853e Mon Sep 17 00:00:00 2001 From: Mikhail Zabaluev Date: Fri, 29 Mar 2024 15:07:05 +0200 Subject: [PATCH 3/8] Use OpenEnum in generated code Replace i32 with the type-checked wrapper. 
--- prost-build/src/code_generator.rs | 11 ++++++--- prost-derive/src/field/map.rs | 24 +++++++----------- prost-derive/src/field/scalar.rs | 41 ++++++++++++++----------------- prost-types/src/protobuf.rs | 30 ++++++++++++---------- tests/src/custom_debug.rs | 14 +++++------ tests/src/debug.rs | 22 +++++++++++------ tests/src/lib.rs | 2 +- tests/src/message_encoding.rs | 23 ++++++++--------- 8 files changed, 86 insertions(+), 81 deletions(-) diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index f8d341445..4014dcb99 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -402,7 +402,7 @@ impl CodeGenerator<'_> { fn append_field(&mut self, fq_message_name: &str, field: &Field) { let type_ = field.descriptor.r#type(); - let repeated = field.descriptor.label == Some(Label::Repeated as i32); + let repeated = field.descriptor.label.and_then(|v| v.known()) == Some(Label::Repeated); let deprecated = self.deprecated(&field.descriptor); let optional = self.optional(&field.descriptor); let boxed = self.boxed(&field.descriptor, fq_message_name, None); @@ -953,7 +953,7 @@ impl CodeGenerator<'_> { Type::Double => String::from("f64"), Type::Uint32 | Type::Fixed32 => String::from("u32"), Type::Uint64 | Type::Fixed64 => String::from("u64"), - Type::Int32 | Type::Sfixed32 | Type::Sint32 | Type::Enum => String::from("i32"), + Type::Int32 | Type::Sfixed32 | Type::Sint32 => String::from("i32"), Type::Int64 | Type::Sfixed64 | Type::Sint64 => String::from("i64"), Type::Bool => String::from("bool"), Type::String => format!("{}::alloc::string::String", prost_path(self.config)), @@ -966,6 +966,11 @@ impl CodeGenerator<'_> { .rust_type() .to_owned(), Type::Group | Type::Message => self.resolve_ident(field.type_name()), + Type::Enum => format!( + "{}::OpenEnum<{}>", + prost_path(self.config), + self.resolve_ident(field.type_name()) + ), } } @@ -1069,7 +1074,7 @@ impl CodeGenerator<'_> { fq_message_name: &str, oneof: 
Option<&str>, ) -> bool { - let repeated = field.label == Some(Label::Repeated as i32); + let repeated = field.label.and_then(|v| v.known()) == Some(Label::Repeated); let fd_type = field.r#type(); if !repeated && (fd_type == Type::Message || fd_type == Type::Group) diff --git a/prost-derive/src/field/map.rs b/prost-derive/src/field/map.rs index 4855cc5c6..a20a586ad 100644 --- a/prost-derive/src/field/map.rs +++ b/prost-derive/src/field/map.rs @@ -132,13 +132,13 @@ impl Field { let module = self.map_ty.module(); match &self.value_ty { ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => { - let default = quote!(#ty::default() as i32); + let default = quote!(::prost::OpenEnum::from(#ty::default())); quote! { ::prost::encoding::#module::encode_with_default( #ke, #kl, - ::prost::encoding::int32::encode, - ::prost::encoding::int32::encoded_len, + ::prost::encoding::enumeration::encode, + ::prost::encoding::enumeration::encoded_len, &(#default), #tag, &#ident, @@ -184,11 +184,11 @@ impl Field { let module = self.map_ty.module(); match &self.value_ty { ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => { - let default = quote!(#ty::default() as i32); + let default = quote!(::prost::OpenEnum::from(#ty::default())); quote! { ::prost::encoding::#module::merge_with_default( #km, - ::prost::encoding::int32::merge, + ::prost::encoding::enumeration::merge, #default, &mut #ident, buf, @@ -221,11 +221,11 @@ impl Field { let module = self.map_ty.module(); match &self.value_ty { ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => { - let default = quote!(#ty::default() as i32); + let default = quote!(::prost::OpenEnum::from(#ty::default())); quote! { ::prost::encoding::#module::encoded_len_with_default( #kl, - ::prost::encoding::int32::encoded_len, + ::prost::encoding::enumeration::encoded_len, &(#default), #tag, &#ident, @@ -275,17 +275,11 @@ impl Field { Some(quote! 
{ #[doc=#get_doc] pub fn #get(&self, key: #key_ref_ty) -> ::core::option::Option<#ty> { - self.#ident.get(#take_ref key).cloned().and_then(|x| { - let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x); - result.ok() - }) + self.#ident.get(#take_ref key).cloned().and_then(|x| { x.known() }) } #[doc=#insert_doc] pub fn #insert(&mut self, key: #key_ty, value: #ty) -> ::core::option::Option<#ty> { - self.#ident.insert(key, value as i32).and_then(|x| { - let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x); - result.ok() - }) + self.#ident.insert(key, value.into()).and_then(|x| { x.known() }) } }) } else { diff --git a/prost-derive/src/field/scalar.rs b/prost-derive/src/field/scalar.rs index c2e870524..62a9d4fd7 100644 --- a/prost-derive/src/field/scalar.rs +++ b/prost-derive/src/field/scalar.rs @@ -216,13 +216,12 @@ impl Field { fn debug_inner(&self, wrap_name: TokenStream) -> TokenStream { if let Ty::Enumeration(ref ty) = self.ty { quote! { - struct #wrap_name<'a>(&'a i32); + struct #wrap_name<'a>(&'a ::prost::OpenEnum<#ty>); impl<'a> ::core::fmt::Debug for #wrap_name<'a> { fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { - let res: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(*self.0); - match res { - Err(_) => ::core::fmt::Debug::fmt(&self.0, f), - Ok(en) => ::core::fmt::Debug::fmt(&en, f), + match self.0.known() { + Some(en) => ::core::fmt::Debug::fmt(&en, f), + None => ::core::fmt::Debug::fmt(&self.0, f), } } } @@ -297,12 +296,12 @@ impl Field { quote! { #[doc=#get_doc] pub fn #get(&self) -> #ty { - ::core::convert::TryFrom::try_from(self.#ident).unwrap_or(#default) + self.#ident.unwrap_or(#default) } #[doc=#set_doc] pub fn #set(&mut self, value: #ty) { - self.#ident = value as i32; + self.#ident = value.into(); } } } @@ -315,15 +314,12 @@ impl Field { quote! 
{ #[doc=#get_doc] pub fn #get(&self) -> #ty { - self.#ident.and_then(|x| { - let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x); - result.ok() - }).unwrap_or(#default) + self.#ident.and_then(|x| { x.known() }).unwrap_or(#default) } #[doc=#set_doc] pub fn #set(&mut self, value: #ty) { - self.#ident = ::core::option::Option::Some(value as i32); + self.#ident = ::core::option::Option::Some(value.into()); } } } @@ -334,20 +330,18 @@ impl Field { ); let push = Ident::new(&format!("push_{}", ident_str), Span::call_site()); let push_doc = format!("Appends the provided enum value to `{}`.", ident_str); + let wrapped_ty = quote!(::prost::OpenEnum<#ty>); quote! { #[doc=#iter_doc] pub fn #get(&self) -> ::core::iter::FilterMap< - ::core::iter::Cloned<::core::slice::Iter>, - fn(i32) -> ::core::option::Option<#ty>, + ::core::iter::Cloned<::core::slice::Iter<#wrapped_ty>>, + fn(#wrapped_ty) -> ::core::option::Option<#ty>, > { - self.#ident.iter().cloned().filter_map(|x| { - let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x); - result.ok() - }) + self.#ident.iter().cloned().filter_map(|x| { x.known() }) } #[doc=#push_doc] pub fn #push(&mut self, value: #ty) { - self.#ident.push(value as i32); + self.#ident.push(value.into()); } } } @@ -533,13 +527,14 @@ impl Ty { match self { Ty::String => quote!(::prost::alloc::string::String), Ty::Bytes(ty) => ty.rust_type(), + Ty::Enumeration(path) => quote!(::prost::OpenEnum<#path>), _ => self.rust_ref_type(), } } // TODO: rename to 'ref_type' pub fn rust_ref_type(&self) -> TokenStream { - match *self { + match self { Ty::Double => quote!(f64), Ty::Float => quote!(f32), Ty::Int32 => quote!(i32), @@ -555,13 +550,13 @@ impl Ty { Ty::Bool => quote!(bool), Ty::String => quote!(&str), Ty::Bytes(..) => quote!(&[u8]), - Ty::Enumeration(..) => quote!(i32), + Ty::Enumeration(..) 
=> unreachable!("an enum should never be queried for its ref type"), } } pub fn module(&self) -> Ident { match *self { - Ty::Enumeration(..) => Ident::new("int32", Span::call_site()), + Ty::Enumeration(..) => Ident::new("enumeration", Span::call_site()), _ => Ident::new(self.as_str(), Span::call_site()), } } @@ -799,7 +794,7 @@ impl DefaultValue { pub fn typed(&self) -> TokenStream { if let DefaultValue::Enumeration(_) = *self { - quote!(#self as i32) + quote!(::prost::OpenEnum::from(#self)) } else { quote!(#self) } diff --git a/prost-types/src/protobuf.rs b/prost-types/src/protobuf.rs index 6f75dfc2b..6c2b618a0 100644 --- a/prost-types/src/protobuf.rs +++ b/prost-types/src/protobuf.rs @@ -113,11 +113,11 @@ pub struct FieldDescriptorProto { #[prost(int32, optional, tag = "3")] pub number: ::core::option::Option, #[prost(enumeration = "field_descriptor_proto::Label", optional, tag = "4")] - pub label: ::core::option::Option, + pub label: ::core::option::Option<::prost::OpenEnum>, /// If type_name is set, this need not be set. If both this and type_name /// are set, this must be one of TYPE_ENUM, TYPE_MESSAGE or TYPE_GROUP. #[prost(enumeration = "field_descriptor_proto::Type", optional, tag = "5")] - pub r#type: ::core::option::Option, + pub r#type: ::core::option::Option<::prost::OpenEnum>, /// For message and enum types, this is the name of the type. If the name /// starts with a '.', it is fully-qualified. Otherwise, C++-like scoping /// rules are used to find the type (i.e. first the nested types within this @@ -470,7 +470,9 @@ pub struct FileOptions { tag = "9", default = "Speed" )] - pub optimize_for: ::core::option::Option, + pub optimize_for: ::core::option::Option< + ::prost::OpenEnum, + >, /// Sets the Go package where structs generated from this .proto will be /// placed. 
If omitted, the Go package will be derived from the following: /// @@ -664,7 +666,7 @@ pub struct FieldOptions { tag = "1", default = "String" )] - pub ctype: ::core::option::Option, + pub ctype: ::core::option::Option<::prost::OpenEnum>, /// The packed option can be enabled for repeated primitive fields to enable /// a more efficient representation on the wire. Rather than repeatedly /// writing the tag and type for each element, the entire array is encoded as @@ -689,7 +691,7 @@ pub struct FieldOptions { tag = "6", default = "JsNormal" )] - pub jstype: ::core::option::Option, + pub jstype: ::core::option::Option<::prost::OpenEnum>, /// Should this field be parsed lazily? Lazy applies only to message-type /// fields. It means that when the outer message is initially parsed, the /// inner message's contents will not be parsed but instead stored in encoded @@ -877,7 +879,9 @@ pub struct MethodOptions { tag = "34", default = "IdempotencyUnknown" )] - pub idempotency_level: ::core::option::Option, + pub idempotency_level: ::core::option::Option< + ::prost::OpenEnum, + >, /// The parser stores options it doesn't recognize here. See above. #[prost(message, repeated, tag = "999")] pub uninterpreted_option: ::prost::alloc::vec::Vec, @@ -1302,17 +1306,17 @@ pub struct Type { pub source_context: ::core::option::Option, /// The source syntax. #[prost(enumeration = "Syntax", tag = "6")] - pub syntax: i32, + pub syntax: ::prost::OpenEnum, } /// A single field of a message type. #[derive(Clone, PartialEq, ::prost::Message)] pub struct Field { /// The field type. #[prost(enumeration = "field::Kind", tag = "1")] - pub kind: i32, + pub kind: ::prost::OpenEnum, /// The field cardinality. #[prost(enumeration = "field::Cardinality", tag = "2")] - pub cardinality: i32, + pub cardinality: ::prost::OpenEnum, /// The field number. #[prost(int32, tag = "3")] pub number: i32, @@ -1514,7 +1518,7 @@ pub struct Enum { pub source_context: ::core::option::Option, /// The source syntax. 
#[prost(enumeration = "Syntax", tag = "5")] - pub syntax: i32, + pub syntax: ::prost::OpenEnum, } /// Enum value definition. #[derive(Clone, PartialEq, ::prost::Message)] @@ -1626,7 +1630,7 @@ pub struct Api { pub mixins: ::prost::alloc::vec::Vec, /// The source syntax of the service. #[prost(enumeration = "Syntax", tag = "7")] - pub syntax: i32, + pub syntax: ::prost::OpenEnum, } /// Method represents a method of an API interface. #[derive(Clone, PartialEq, ::prost::Message)] @@ -1651,7 +1655,7 @@ pub struct Method { pub options: ::prost::alloc::vec::Vec