Skip to content

Commit 3562a58

Browse files
committed
feat: derive Eq and Hash trait for messages where possible
Integer and bytes types can be compared using trait Eq. Some generated Rust structs can also have this property by deriving the Eq trait. Automatically derive Eq and Hash for: - messages that only have fields with integer or bytes types - messages where all field types also implement Eq and Hash - the Rust enum for one-of fields, where all fields implement Eq and Hash Generated code for Protobuf enums already derives Eq and Hash. BREAKING CHANGE: `prost-build` will automatically derive `trait Eq` and `trait Hash` for types where all field support those as well. If you manually `impl Eq` and/or `impl Hash` for generated types, then you need to remove the manual implementation. If you use `type_attribute` to `derive(Eq)` and/or `derive(Hash)`, then you need to remove those.
1 parent 86f87a2 commit 3562a58

File tree

10 files changed

+77
-45
lines changed

10 files changed

+77
-45
lines changed

prost-build/src/code_generator.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,12 +229,17 @@ impl<'a> CodeGenerator<'a> {
229229
self.append_message_attributes(&fq_message_name);
230230
self.push_indent();
231231
self.buf.push_str(&format!(
232-
"#[derive(Clone, {}PartialEq, {}::Message)]\n",
232+
"#[derive(Clone, {}PartialEq, {}{}::Message)]\n",
233233
if self.message_graph.can_message_derive_copy(&fq_message_name) {
234234
"Copy, "
235235
} else {
236236
""
237237
},
238+
if self.message_graph.can_message_derive_eq(&fq_message_name) {
239+
"Eq, Hash, "
240+
} else {
241+
""
242+
},
238243
prost_path(self.config)
239244
));
240245
self.append_skip_debug(&fq_message_name);
@@ -619,9 +624,14 @@ impl<'a> CodeGenerator<'a> {
619624
self.message_graph
620625
.can_field_derive_copy(fq_message_name, &field.descriptor)
621626
});
627+
let can_oneof_derive_eq = oneof.fields.iter().all(|field| {
628+
self.message_graph
629+
.can_field_derive_eq(fq_message_name, &field.descriptor)
630+
});
622631
self.buf.push_str(&format!(
623-
"#[derive(Clone, {}PartialEq, {}::Oneof)]\n",
632+
"#[derive(Clone, {}PartialEq, {}{}::Oneof)]\n",
624633
if can_oneof_derive_copy { "Copy, " } else { "" },
634+
if can_oneof_derive_eq { "Eq, Hash, " } else { "" },
625635
prost_path(self.config)
626636
));
627637
self.append_skip_debug(fq_message_name);
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,28 @@
11
// This file is @generated by prost-build.
2-
#[derive(Clone, PartialEq, ::prost::Message)]
2+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
33
pub struct Container {
44
#[prost(oneof = "container::Data", tags = "1, 2")]
55
pub data: ::core::option::Option<container::Data>,
66
}
77
/// Nested message and enum types in `Container`.
88
pub mod container {
9-
#[derive(Clone, PartialEq, ::prost::Oneof)]
9+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)]
1010
pub enum Data {
1111
#[prost(message, tag = "1")]
1212
Foo(::prost::alloc::boxed::Box<super::Foo>),
1313
#[prost(message, tag = "2")]
1414
Bar(super::Bar),
1515
}
1616
}
17-
#[derive(Clone, PartialEq, ::prost::Message)]
17+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
1818
pub struct Foo {
1919
#[prost(string, tag = "1")]
2020
pub foo: ::prost::alloc::string::String,
2121
}
22-
#[derive(Clone, PartialEq, ::prost::Message)]
22+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
2323
pub struct Bar {
2424
#[prost(message, optional, boxed, tag = "1")]
2525
pub qux: ::core::option::Option<::prost::alloc::boxed::Box<Qux>>,
2626
}
27-
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
27+
#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)]
2828
pub struct Qux {}

prost-build/src/fixtures/helloworld/_expected_helloworld_formatted.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
// This file is @generated by prost-build.
22
#[derive(derive_builder::Builder)]
33
#[derive(custom_proto::Input)]
4-
#[derive(Clone, PartialEq, ::prost::Message)]
4+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
55
pub struct Message {
66
#[prost(string, tag = "1")]
77
pub say: ::prost::alloc::string::String,
88
}
99
#[derive(derive_builder::Builder)]
1010
#[derive(custom_proto::Output)]
11-
#[derive(Clone, PartialEq, ::prost::Message)]
11+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
1212
pub struct Response {
1313
#[prost(string, tag = "1")]
1414
pub say: ::prost::alloc::string::String,

prost-build/src/message_graph.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,4 +153,47 @@ impl MessageGraph {
153153
)
154154
}
155155
}
156+
157+
/// Returns `true` if this message can automatically derive Eq trait.
158+
pub fn can_message_derive_eq(&self, fq_message_name: &str) -> bool {
159+
assert_eq!(".", &fq_message_name[..1]);
160+
161+
let msg = self.messages.get(fq_message_name).unwrap();
162+
msg.field
163+
.iter()
164+
.all(|field| self.can_field_derive_eq(fq_message_name, field))
165+
}
166+
167+
/// Returns `true` if the type of this field allows deriving the Eq trait.
168+
pub fn can_field_derive_eq(&self, fq_message_name: &str, field: &FieldDescriptorProto) -> bool {
169+
assert_eq!(".", &fq_message_name[..1]);
170+
171+
if field.r#type() == Type::Message {
172+
if field.label() == Label::Repeated {
173+
false
174+
} else if self.is_nested(field.type_name(), fq_message_name) {
175+
false
176+
} else {
177+
self.can_message_derive_eq(field.type_name())
178+
}
179+
} else {
180+
matches!(
181+
field.r#type(),
182+
Type::Int32
183+
| Type::Int64
184+
| Type::Uint32
185+
| Type::Uint64
186+
| Type::Sint32
187+
| Type::Sint64
188+
| Type::Fixed32
189+
| Type::Fixed64
190+
| Type::Sfixed32
191+
| Type::Sfixed64
192+
| Type::Bool
193+
| Type::Enum
194+
| Type::String
195+
| Type::Bytes
196+
)
197+
}
198+
}
156199
}

prost-types/src/compiler.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// This file is @generated by prost-build.
22
/// The version number of protocol compiler.
3-
#[derive(Clone, PartialEq, ::prost::Message)]
3+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
44
pub struct Version {
55
#[prost(int32, optional, tag = "1")]
66
pub major: ::core::option::Option<i32>,

prost-types/src/duration.rs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,5 @@
11
use super::*;
22

3-
#[cfg(feature = "std")]
4-
impl std::hash::Hash for Duration {
5-
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
6-
self.seconds.hash(state);
7-
self.nanos.hash(state);
8-
}
9-
}
10-
113
impl Duration {
124
/// Normalizes the duration to a canonical format.
135
///

prost-types/src/protobuf.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ pub mod descriptor_proto {
8989
/// Range of reserved tag numbers. Reserved tag numbers may not be used by
9090
/// fields or extension ranges in the same message. Reserved ranges may
9191
/// not overlap.
92-
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
92+
#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)]
9393
pub struct ReservedRange {
9494
/// Inclusive.
9595
#[prost(int32, optional, tag = "1")]
@@ -350,7 +350,7 @@ pub mod enum_descriptor_proto {
350350
/// Note that this is distinct from DescriptorProto.ReservedRange in that it
351351
/// is inclusive such that it can appropriately represent the entire int32
352352
/// domain.
353-
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
353+
#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)]
354354
pub struct EnumReservedRange {
355355
/// Inclusive.
356356
#[prost(int32, optional, tag = "1")]
@@ -961,7 +961,7 @@ pub mod uninterpreted_option {
961961
/// extension (denoted with parentheses in options specs in .proto files).
962962
/// E.g.,{ \["foo", false\], \["bar.baz", true\], \["qux", false\] } represents
963963
/// "foo.(bar.baz).qux".
964-
#[derive(Clone, PartialEq, ::prost::Message)]
964+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
965965
pub struct NamePart {
966966
#[prost(string, required, tag = "1")]
967967
pub name_part: ::prost::alloc::string::String,
@@ -1022,7 +1022,7 @@ pub struct SourceCodeInfo {
10221022
}
10231023
/// Nested message and enum types in `SourceCodeInfo`.
10241024
pub mod source_code_info {
1025-
#[derive(Clone, PartialEq, ::prost::Message)]
1025+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
10261026
pub struct Location {
10271027
/// Identifies which part of the FileDescriptorProto was defined at this
10281028
/// location.
@@ -1125,7 +1125,7 @@ pub struct GeneratedCodeInfo {
11251125
}
11261126
/// Nested message and enum types in `GeneratedCodeInfo`.
11271127
pub mod generated_code_info {
1128-
#[derive(Clone, PartialEq, ::prost::Message)]
1128+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
11291129
pub struct Annotation {
11301130
/// Identifies the element in the original source .proto file. This field
11311131
/// is formatted the same as SourceCodeInfo.Location.path.
@@ -1238,7 +1238,7 @@ pub mod generated_code_info {
12381238
/// "value": "1.212s"
12391239
/// }
12401240
/// ```
1241-
#[derive(Clone, PartialEq, ::prost::Message)]
1241+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
12421242
pub struct Any {
12431243
/// A URL/resource name that uniquely identifies the type of the serialized
12441244
/// protocol buffer message. This string must contain at least
@@ -1275,7 +1275,7 @@ pub struct Any {
12751275
}
12761276
/// `SourceContext` represents information about the source of a
12771277
/// protobuf element, like the file in which it is defined.
1278-
#[derive(Clone, PartialEq, ::prost::Message)]
1278+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
12791279
pub struct SourceContext {
12801280
/// The path-qualified name of the .proto file that contained the associated
12811281
/// protobuf element. For example: `"google/protobuf/source_context.proto"`.
@@ -1531,7 +1531,7 @@ pub struct EnumValue {
15311531
}
15321532
/// A protocol buffer option, which can be attached to a message, field,
15331533
/// enumeration, etc.
1534-
#[derive(Clone, PartialEq, ::prost::Message)]
1534+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
15351535
pub struct Option {
15361536
/// The option's name. For protobuf built-in options (options defined in
15371537
/// descriptor.proto), this is the short name. For example, `"map_entry"`.
@@ -1741,7 +1741,7 @@ pub struct Method {
17411741
/// ...
17421742
/// }
17431743
/// ```
1744-
#[derive(Clone, PartialEq, ::prost::Message)]
1744+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
17451745
pub struct Mixin {
17461746
/// The fully qualified name of the interface which is included.
17471747
#[prost(string, tag = "1")]
@@ -1815,7 +1815,7 @@ pub struct Mixin {
18151815
/// encoded in JSON format as "3s", while 3 seconds and 1 nanosecond should
18161816
/// be expressed in JSON format as "3.000000001s", and 3 seconds and 1
18171817
/// microsecond should be expressed in JSON format as "3.000001s".
1818-
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
1818+
#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)]
18191819
pub struct Duration {
18201820
/// Signed seconds of the span of time. Must be from -315,576,000,000
18211821
/// to +315,576,000,000 inclusive. Note: these bounds are computed from:
@@ -2053,7 +2053,7 @@ pub struct Duration {
20532053
/// The implementation of any API method which has a FieldMask type field in the
20542054
/// request should verify the included field paths, and return an
20552055
/// `INVALID_ARGUMENT` error if any path is unmappable.
2056-
#[derive(Clone, PartialEq, ::prost::Message)]
2056+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
20572057
pub struct FieldMask {
20582058
/// The set of field mask paths.
20592059
#[prost(string, repeated, tag = "1")]
@@ -2249,7 +2249,7 @@ impl NullValue {
22492249
/// [`strftime`](<https://docs.python.org/2/library/time.html#time.strftime>) with
22502250
/// the time format spec '%Y-%m-%dT%H:%M:%S.%fZ'. Likewise, in Java, one can use
22512251
/// the Joda Time's [`ISODateTimeFormat.dateTime()`](<http://www.joda.org/joda-time/apidocs/org/joda/time/format/ISODateTimeFormat.html#dateTime%2D%2D>) to obtain a formatter capable of generating timestamps in this format.
2252-
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
2252+
#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)]
22532253
pub struct Timestamp {
22542254
/// Represents seconds of UTC time since Unix epoch
22552255
/// 1970-01-01T00:00:00Z. Must be from 0001-01-01T00:00:00Z to

prost-types/src/timestamp.rs

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -123,19 +123,6 @@ impl Name for Timestamp {
123123
}
124124
}
125125

126-
/// Implements the unstable/naive version of `Eq`: a basic equality check on the internal fields of the `Timestamp`.
127-
/// This implies that `normalized_ts != non_normalized_ts` even if `normalized_ts == non_normalized_ts.normalized()`.
128-
#[cfg(feature = "std")]
129-
impl Eq for Timestamp {}
130-
131-
#[cfg(feature = "std")]
132-
impl std::hash::Hash for Timestamp {
133-
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
134-
self.seconds.hash(state);
135-
self.nanos.hash(state);
136-
}
137-
}
138-
139126
#[cfg(feature = "std")]
140127
impl From<std::time::SystemTime> for Timestamp {
141128
fn from(system_time: std::time::SystemTime) -> Timestamp {

tests/build.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ fn main() {
3838
config.type_attribute("Foo.Custom.Attrs.AnotherEnum", "/// Oneof docs");
3939
config.type_attribute(
4040
"Foo.Custom.OneOfAttrs.Msg.field",
41-
"#[derive(Eq, PartialOrd, Ord)]",
41+
"#[derive(PartialOrd, Ord)]",
4242
);
4343
config.field_attribute("Foo.Custom.Attrs.AnotherEnum.C", "/// The C docs");
4444
config.field_attribute("Foo.Custom.Attrs.AnotherEnum.D", "/// The D docs");

tests/single-include/src/outdir/outdir.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// This file is @generated by prost-build.
2-
#[derive(Clone, PartialEq, ::prost::Message)]
2+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
33
pub struct OutdirRequest {
44
#[prost(string, tag = "1")]
55
pub query: ::prost::alloc::string::String,

0 commit comments

Comments
 (0)