Skip to content

Commit

Permalink
feature(prost-build): Generate less boxed if nested type is boxed man…
Browse files Browse the repository at this point in the history
…ually
  • Loading branch information
ldm0 committed Sep 20, 2024
1 parent fb977f4 commit 13f0b78
Show file tree
Hide file tree
Showing 7 changed files with 310 additions and 14 deletions.
2 changes: 1 addition & 1 deletion prost-build/src/code_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1075,7 +1075,7 @@ impl<'a> CodeGenerator<'a> {
&& (fd_type == Type::Message || fd_type == Type::Group)
&& self
.message_graph
.is_nested(field.type_name(), fq_message_name)
.is_directly_nested(field.type_name(), fq_message_name)
{
return true;
}
Expand Down
142 changes: 129 additions & 13 deletions prost-build/src/message_graph.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};

use petgraph::algo::has_path_connecting;
use petgraph::graph::NodeIndex;
use petgraph::Graph;
use petgraph::visit::{EdgeRef, VisitMap};
use petgraph::{Direction, Graph};

use prost_types::{
field_descriptor_proto::{Label, Type},
Expand All @@ -15,9 +15,13 @@ use crate::path::PathMap;
/// The goal is to recognize when message types are recursively nested, so
/// that fields can be boxed when necessary.
pub struct MessageGraph {
/// Map<fq type name, graph node index>
index: HashMap<String, NodeIndex>,
graph: Graph<String, ()>,
/// Graph with fq type name as node, field name as edge
graph: Graph<String, String>,
/// Map<fq type name, DescriptorProto>
messages: HashMap<String, DescriptorProto>,
/// Manually boxed fields
boxed: PathMap<()>,
}

Expand Down Expand Up @@ -71,7 +75,8 @@ impl MessageGraph {
for field in &msg.field {
if field.r#type() == Type::Message && field.label() != Label::Repeated {
let field_index = self.get_or_insert_index(field.type_name.clone().unwrap());
self.graph.add_edge(msg_index, field_index, ());
self.graph
.add_edge(msg_index, field_index, field.name.clone().unwrap());
}
}
self.messages.insert(msg_name.clone(), msg.clone());
Expand All @@ -86,8 +91,9 @@ impl MessageGraph {
self.messages.get(message)
}

/// Returns true if message type `inner` is nested in message type `outer`.
pub fn is_nested(&self, outer: &str, inner: &str) -> bool {
/// Returns true if message type `inner` is nested in message type `outer`,
/// and no field edge in the chain of dependencies is manually boxed.
pub fn is_directly_nested(&self, outer: &str, inner: &str) -> bool {
let outer = match self.index.get(outer) {
Some(outer) => *outer,
None => return false,
Expand All @@ -97,7 +103,12 @@ impl MessageGraph {
None => return false,
};

has_path_connecting(&self.graph, outer, inner, None)
// Check if `inner` is nested in `outer` and ensure that all edge fields are not boxed manually.
is_connected_with_edge_filter(&self.graph, outer, inner, |node, field_name| {
self.boxed
.get_first_field(&self.graph[node], field_name)
.is_none()
})
}

/// Returns `true` if this message can automatically derive Copy trait.
Expand All @@ -123,11 +134,11 @@ impl MessageGraph {
false
} else if field.r#type() == Type::Message {
// nested and boxed messages cannot derive Copy
if self.is_nested(field.type_name(), fq_message_name)
|| self
.boxed
.get_first_field(fq_message_name, field.name())
.is_some()
if self
.boxed
.get_first_field(fq_message_name, field.name())
.is_some()
|| self.is_directly_nested(field.type_name(), fq_message_name)
{
false
} else {
Expand All @@ -154,3 +165,108 @@ impl MessageGraph {
}
}
}

/// Check two nodes is connected with edge filter
fn is_connected_with_edge_filter<F, N, E>(
graph: &Graph<N, E>,
start: NodeIndex,
end: NodeIndex,
mut is_good_edge: F,
) -> bool
where
F: FnMut(NodeIndex, &E) -> bool,
{
fn visitor<F, N, E>(
graph: &Graph<N, E>,
start: NodeIndex,
end: NodeIndex,
is_good_edge: &mut F,
visited: &mut HashSet<NodeIndex>,
) -> bool
where
F: FnMut(NodeIndex, &E) -> bool,
{
if start == end {
return true;
}
visited.visit(start);
for edge in graph.edges_directed(start, Direction::Outgoing) {
// if the edge doesn't pass the filter, skip it
if !is_good_edge(start, edge.weight()) {
continue;
}
let target = edge.target();
if visited.is_visited(&target) {
continue;
}
if visitor(graph, target, end, is_good_edge, visited) {
return true;
}
}
false
}
let mut visited = HashSet::new();
visitor(graph, start, end, &mut is_good_edge, &mut visited)
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_connected() {
let mut graph = Graph::new();
let n1 = graph.add_node(1);
let n2 = graph.add_node(2);
let n3 = graph.add_node(3);
let n4 = graph.add_node(4);
let n5 = graph.add_node(5);
let n6 = graph.add_node(6);
let n7 = graph.add_node(7);
let n8 = graph.add_node(8);
graph.add_edge(n1, n2, 1.);
graph.add_edge(n2, n3, 2.);
graph.add_edge(n3, n4, 3.);
graph.add_edge(n4, n5, 4.);
graph.add_edge(n5, n6, 5.);
graph.add_edge(n6, n7, 6.);
graph.add_edge(n7, n8, 7.);
graph.add_edge(n8, n1, 8.);
assert!(is_connected_with_edge_filter(&graph, n2, n1, |_, edge| {
dbg!(edge);
true
}),);
assert!(is_connected_with_edge_filter(&graph, n2, n1, |_, edge| {
dbg!(edge);
edge < &8.5
}),);
assert!(!is_connected_with_edge_filter(&graph, n2, n1, |_, edge| {
dbg!(edge);
edge < &7.5
}),);
}

#[test]
fn test_connected_multi_circle() {
let mut graph = Graph::new();
let n0 = graph.add_node(0);
let n1 = graph.add_node(1);
let n2 = graph.add_node(2);
let n3 = graph.add_node(3);
let n4 = graph.add_node(4);
graph.add_edge(n0, n1, 0.);
graph.add_edge(n1, n2, 1.);
graph.add_edge(n2, n3, 2.);
graph.add_edge(n3, n0, 3.);
graph.add_edge(n1, n4, 1.5);
graph.add_edge(n4, n0, 2.5);
assert!(is_connected_with_edge_filter(&graph, n1, n0, |_, edge| {
dbg!(edge);
edge < &2.8
}),);
assert!(!is_connected_with_edge_filter(&graph, n1, n0, |_, edge| {
dbg!(edge);
edge < &2.1
}),);
}
}
13 changes: 13 additions & 0 deletions tests/src/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,19 @@ fn main() {

std::fs::create_dir_all(&out_path).unwrap();

prost_build::Config::new()
.out_dir(src.join("nesting_complex/boxed"))
.boxed("Foo.bar")
.boxed("BazB.baz_c")
.boxed("BakC.bak_d")
.compile_protos(&[src.join("nesting_complex.proto")], includes)
.unwrap();

prost_build::Config::new()
.out_dir(src.join("nesting_complex/"))
.compile_protos(&[src.join("nesting_complex.proto")], includes)
.unwrap();

prost_build::Config::new()
.bytes(["."])
.out_dir(out_path)
Expand Down
8 changes: 8 additions & 0 deletions tests/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,14 @@ pub mod proto3 {
}
}

pub mod nesting_complex_boxed {
include!("nesting_complex/boxed/nesting_complex.rs");
}

pub mod nesting_complex {
include!("nesting_complex/nesting_complex.rs");
}

pub mod invalid {
pub mod doctest {
include!(concat!(env!("OUT_DIR"), "/invalid.doctest.rs"));
Expand Down
47 changes: 47 additions & 0 deletions tests/src/nesting_complex.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
syntax = "proto2";

package nesting_complex;

// ----- Directly nested
message Foo {
optional Bar bar = 1;
}

message Bar {
optional Foo foo = 1;
}

// ----- Transitively nested
message BazA {
optional BazB baz_b = 1;
}

message BazB {
optional BazC baz_c = 1;
}

message BazC {
optional BazA baz_a = 1;
}

// ----- Transitively nested in two chain
message BakA {
optional BakB bak_b = 1;
}

message BakB {
optional BakC bak_c = 1;
optional BakE bak_e = 2;
}

message BakC {
optional BakD bak_d = 1;
}

message BakD {
optional BakA bak_a = 1;
}

message BakE {
optional BakA bak_a = 1;
}
56 changes: 56 additions & 0 deletions tests/src/nesting_complex/boxed/nesting_complex.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// This file is @generated by prost-build.
/// ----- Directly nested
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Foo {
#[prost(message, optional, boxed, tag = "1")]
pub bar: ::core::option::Option<::prost::alloc::boxed::Box<Bar>>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Bar {
#[prost(message, optional, tag = "1")]
pub foo: ::core::option::Option<Foo>,
}
/// ----- Transitively nested
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BazA {
#[prost(message, optional, tag = "1")]
pub baz_b: ::core::option::Option<BazB>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BazB {
#[prost(message, optional, boxed, tag = "1")]
pub baz_c: ::core::option::Option<::prost::alloc::boxed::Box<BazC>>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BazC {
#[prost(message, optional, tag = "1")]
pub baz_a: ::core::option::Option<BazA>,
}
/// ----- Transitively nested in two chain
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BakA {
#[prost(message, optional, boxed, tag = "1")]
pub bak_b: ::core::option::Option<::prost::alloc::boxed::Box<BakB>>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BakB {
#[prost(message, optional, tag = "1")]
pub bak_c: ::core::option::Option<BakC>,
#[prost(message, optional, boxed, tag = "2")]
pub bak_e: ::core::option::Option<::prost::alloc::boxed::Box<BakE>>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BakC {
#[prost(message, optional, boxed, tag = "1")]
pub bak_d: ::core::option::Option<::prost::alloc::boxed::Box<BakD>>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BakD {
#[prost(message, optional, tag = "1")]
pub bak_a: ::core::option::Option<BakA>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BakE {
#[prost(message, optional, boxed, tag = "1")]
pub bak_a: ::core::option::Option<::prost::alloc::boxed::Box<BakA>>,
}
56 changes: 56 additions & 0 deletions tests/src/nesting_complex/nesting_complex.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// This file is @generated by prost-build.
/// ----- Directly nested
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Foo {
#[prost(message, optional, boxed, tag = "1")]
pub bar: ::core::option::Option<::prost::alloc::boxed::Box<Bar>>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Bar {
#[prost(message, optional, boxed, tag = "1")]
pub foo: ::core::option::Option<::prost::alloc::boxed::Box<Foo>>,
}
/// ----- Transitively nested
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BazA {
#[prost(message, optional, boxed, tag = "1")]
pub baz_b: ::core::option::Option<::prost::alloc::boxed::Box<BazB>>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BazB {
#[prost(message, optional, boxed, tag = "1")]
pub baz_c: ::core::option::Option<::prost::alloc::boxed::Box<BazC>>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BazC {
#[prost(message, optional, boxed, tag = "1")]
pub baz_a: ::core::option::Option<::prost::alloc::boxed::Box<BazA>>,
}
/// ----- Transitively nested in two chain
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BakA {
#[prost(message, optional, boxed, tag = "1")]
pub bak_b: ::core::option::Option<::prost::alloc::boxed::Box<BakB>>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BakB {
#[prost(message, optional, boxed, tag = "1")]
pub bak_c: ::core::option::Option<::prost::alloc::boxed::Box<BakC>>,
#[prost(message, optional, boxed, tag = "2")]
pub bak_e: ::core::option::Option<::prost::alloc::boxed::Box<BakE>>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BakC {
#[prost(message, optional, boxed, tag = "1")]
pub bak_d: ::core::option::Option<::prost::alloc::boxed::Box<BakD>>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BakD {
#[prost(message, optional, boxed, tag = "1")]
pub bak_a: ::core::option::Option<::prost::alloc::boxed::Box<BakA>>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BakE {
#[prost(message, optional, boxed, tag = "1")]
pub bak_a: ::core::option::Option<::prost::alloc::boxed::Box<BakA>>,
}

0 comments on commit 13f0b78

Please sign in to comment.