diff --git a/src/nunavut/lang/_config.py b/src/nunavut/lang/_config.py index 12bab95d..a4671d80 100644 --- a/src/nunavut/lang/_config.py +++ b/src/nunavut/lang/_config.py @@ -34,7 +34,7 @@ def parse_string(s: str) -> typing.Optional[typing.Any]: # annoying mypy cheat class SpecialMethod(Enum): """ - Enum used in the Jinja templates to differentiate different kinds of constructrors + Enum used in the Jinja templates to differentiate different kinds of constructors """ AllocatorConstructor = auto() @@ -50,6 +50,18 @@ class SpecialMethod(Enum): """ Move constructor that also takes an allocator argument """ +class CompositeSubType(Enum): + """ + Enum used in the Jinja templates to designate how fields are contained in a composite type + """ + + Structure = auto() + """ Object contains a set of sequential fields """ + + Union = auto() + """ Object contains one field which may hold any value from a set of types """ + + class LanguageConfig: """ Configuration storage encapsulating parsers and other configuration format details. For any configuration type used diff --git a/src/nunavut/lang/cpp/__init__.py b/src/nunavut/lang/cpp/__init__.py index 547e6d0d..923985bf 100644 --- a/src/nunavut/lang/cpp/__init__.py +++ b/src/nunavut/lang/cpp/__init__.py @@ -28,7 +28,7 @@ from nunavut._utilities import YesNoDefault from nunavut.jinja.environment import Environment from nunavut.lang._common import IncludeGenerator, TokenEncoder, UniqueNameGenerator -from nunavut.lang._config import ConstructorConvention, SpecialMethod +from nunavut.lang._config import ConstructorConvention, SpecialMethod, CompositeSubType from nunavut.lang._language import Language as BaseLanguage from nunavut.lang.c import _CFit from nunavut.lang.c import filter_literal as c_filter_literal @@ -155,6 +155,7 @@ def _add_additional_globals(self, globals_map: typing.Dict[str, typing.Any]) -> """ globals_map["ConstructorConvention"] = ConstructorConvention globals_map["SpecialMethod"] = SpecialMethod + globals_map["CompositeSubType"] = CompositeSubType def get_includes(self, dep_types: Dependencies) -> typing.List[str]: """ @@ -967,7 +968,7 @@ def needs_rhs(special_method: SpecialMethod) -> bool: def needs_allocator(instance: pydsdl.Any) -> bool: - """Helper method used by filter_value_initializer()""" + """Helper method used by filter_value_initializer() and filter_needs_allocator()""" return isinstance(instance.data_type, pydsdl.VariableLengthArrayType) or isinstance( instance.data_type, pydsdl.CompositeType ) @@ -980,6 +981,10 @@ def needs_vla_init_args(instance: pydsdl.Any, special_method: SpecialMethod) -> ) +def needs_variant_init_args(composite_subtype: CompositeSubType) -> bool: + return composite_subtype == CompositeSubType.Union + + def needs_move(special_method: SpecialMethod) -> bool: """Helper method used by filter_value_initializer()""" return special_method == SpecialMethod.MoveConstructorWithAllocator @@ -994,12 +999,45 @@ def requires_initialization(instance: pydsdl.Any) -> bool: ) -def assemble_initializer_expression( - wrap: str, rhs: str, leading_args: typing.List[str], trailing_args: typing.List[str] -) -> str: +def prepare_initializer_args( + language: Language, + instance: pydsdl.Any, + special_method: SpecialMethod, + composite_subtype: CompositeSubType, +) -> typing.Tuple[typing.List[str], str, typing.List[str]]: + rhs: str = "" + leading_args: typing.List[str] = [] + trailing_args: typing.List[str] = [] + if needs_variant_init_args(composite_subtype): + leading_args.append( + f"nunavut::support::in_place_index_t{{}}" + ) + + if needs_initializing_value(special_method): + instance_id = language.filter_id(instance) + if needs_rhs(special_method): + rhs = "rhs." + rhs += f"get_{instance_id}()" if composite_subtype is CompositeSubType.Union else instance_id + + if needs_vla_init_args(instance, special_method): + constructor_args = language.get_option("variable_array_type_constructor_args") + if isinstance(constructor_args, str) and len(constructor_args) > 0: + trailing_args.append(constructor_args.format(MAX_SIZE=instance.data_type.capacity)) + + if needs_allocator(instance): + if language.get_option("ctor_convention") == ConstructorConvention.UsesLeadingAllocator.value: + leading_args.extend(["std::allocator_arg", "allocator"]) + else: + trailing_args.append("allocator") + + if needs_move(special_method): + rhs = "std::move({})".format(rhs) + + return (leading_args, rhs, trailing_args) + + +def assemble_initializer_expression(rhs: str, leading_args: typing.List[str], trailing_args: typing.List[str]) -> str: """Helper method used by filter_value_initializer()""" - if wrap: - rhs = "{}({})".format(wrap, rhs) args = [] if rhs: args.append(rhs) @@ -1008,40 +1046,34 @@ def assemble_initializer_expression( @template_language_filter(__name__) -def filter_value_initializer(language: Language, instance: pydsdl.Any, special_method: SpecialMethod) -> str: +def filter_value_initializer( + language: Language, + instance: pydsdl.Any, + special_method: SpecialMethod, + composite_subtype: CompositeSubType = CompositeSubType.Structure, +) -> str: """ Emit an initialization expression for a C++ special method. """ value_initializer: str = "" if requires_initialization(instance): - wrap: str = "" rhs: str = "" leading_args: typing.List[str] = [] trailing_args: typing.List[str] = [] - if needs_initializing_value(special_method): - if needs_rhs(special_method): - rhs = "rhs." - rhs += language.filter_id(instance) - - if needs_vla_init_args(instance, special_method): - constructor_args = language.get_option("variable_array_type_constructor_args") - if isinstance(constructor_args, str) and len(constructor_args) > 0: - trailing_args.append(constructor_args.format(MAX_SIZE=instance.data_type.capacity)) - - if needs_allocator(instance): - if language.get_option("ctor_convention") == ConstructorConvention.UsesLeadingAllocator.value: - leading_args.extend(["std::allocator_arg", "allocator"]) - else: - trailing_args.append("allocator") + leading_args, rhs, trailing_args = prepare_initializer_args( + language, instance, special_method, composite_subtype + ) + value_initializer = assemble_initializer_expression(rhs, leading_args, trailing_args) - if needs_move(special_method): - wrap = "std::move" + return value_initializer - value_initializer = assemble_initializer_expression(wrap, rhs, leading_args, trailing_args) - return value_initializer +@template_language_test(__name__) +def filter_needs_allocator(language: Language, instance: pydsdl.Any) -> bool: + """Emit a boolean value for whether the instance's type needs an allocator or not""" + return needs_allocator(instance) @template_language_filter(__name__) diff --git a/src/nunavut/lang/cpp/support/utility.j2 b/src/nunavut/lang/cpp/support/utility.j2 new file mode 100644 index 00000000..565c7937 --- /dev/null +++ b/src/nunavut/lang/cpp/support/utility.j2 @@ -0,0 +1,39 @@ +// OpenCyphal common union composite type support routines +// +// AUTOGENERATED, DO NOT EDIT. +// +//--------------------------------------------------------------------------------------------------------------------- +// Language Options +{% for key, value in options.items() -%} +// {{ key }}: {{ value }} +{% endfor %} + +#ifndef NUNAVUT_SUPPORT_UTILITIES_HPP_INCLUDED +#define NUNAVUT_SUPPORT_UTILITIES_HPP_INCLUDED + +{% ifuses "std_variant" -%} +#include +{%- else -%} +#include +{%- endifuses %} + +namespace nunavut +{ +namespace support +{ + +// Value-specialized type for template instantiation +template +{% ifuses "std_variant" -%} +using in_place_index_t = std::in_place_index_t; +{%- else -%} +struct in_place_index_t +{ + explicit in_place_index_t() = default; +}; +{%- endifuses %} + +} // end namespace support +} // end namespace nunavut + +#endif // NUNAVUT_SUPPORT_UTILITIES_HPP_INCLUDED diff --git a/src/nunavut/lang/cpp/templates/_composite_type.j2 b/src/nunavut/lang/cpp/templates/_composite_type.j2 index 8f0bd60c..59901bd3 100644 --- a/src/nunavut/lang/cpp/templates/_composite_type.j2 +++ b/src/nunavut/lang/cpp/templates/_composite_type.j2 @@ -84,6 +84,14 @@ struct {% if composite_type.deprecated -%} {%- endif %} {%- endfor %} }; +{% if composite_type.inner_type is UnionType -%} +{%- ifuses "std_variant" -%} +{% include '_fields_as_variant.j2' %} +{%- else -%} +{% include '_fields_as_union.j2' %} +{%- endifuses -%} +{%- endif -%} + {% if options.ctor_convention != ConstructorConvention.Default.value %} {%- if options.allocator_is_default_constructible %} // Default constructor @@ -102,19 +110,26 @@ struct {% if composite_type.deprecated -%} // Allocator constructor explicit {{composite_type|short_reference_name}}(const allocator_type& allocator) - {%- if composite_type.fields_except_padding %} :{% endif %} + {%- if composite_type.fields_except_padding %} : {%- if composite_type.inner_type is UnionType %} - union_value{} // can't make use of the allocator with a union + union_value{{ composite_type.fields_except_padding[0] | value_initializer(SpecialMethod.AllocatorConstructor, CompositeSubType.Union) }} {%- else %} {%- for field in composite_type.fields_except_padding %} {{ field | id }}{{ field | value_initializer(SpecialMethod.AllocatorConstructor) }}{%if not loop.last %},{%endif %} {%- endfor %} {%- endif %} + {%- endif %} { (void)allocator; // avoid unused param warning } - {%- if composite_type.inner_type is not UnionType %} + {% if composite_type.inner_type is UnionType -%} + // Initializing constructor + template + {{composite_type|short_reference_name}}(nunavut::support::in_place_index_t i, Args&&... args) + : union_value{i, std::forward(args)...} + {} + {%- else %} {% if composite_type.fields_except_padding %} // Initializing constructor {{ composite_type | explicit_decorator(SpecialMethod.InitializingConstructorWithAllocator)}}( @@ -140,15 +155,26 @@ struct {% if composite_type.deprecated -%} {{composite_type|short_reference_name}}(const {{composite_type|short_reference_name}}& rhs, const allocator_type& allocator) {%- if composite_type.fields_except_padding %} :{% endif %} {%- if composite_type.inner_type is UnionType %} - union_value{rhs.union_value} // can't make use of the allocator with a union + union_value{std::move( + {%- set ns = namespace(indent = "") %} + {%- for field in composite_type.fields_except_padding %} + {%- if not loop.last %} + {{ns.indent}}rhs.is_{{ field | id }}() ? + {%- set ns.indent = ns.indent ~ "\t" %} + {%- endif %} + {{ns.indent}}VariantType + {{- field | value_initializer(SpecialMethod.CopyConstructorWithAllocator, CompositeSubType.Union) }} + {%- if not loop.last %} :{% endif %} + {%- endfor %} + )} {%- else %} {%- for field in composite_type.fields_except_padding %} {{ field | id }}{{ field | value_initializer(SpecialMethod.CopyConstructorWithAllocator) }}{%if not loop.last %},{%endif %} {%- endfor %} - {% endif %} + {%- endif %} { - (void)rhs; // avoid unused param warning - (void)allocator; // avoid unused param warning + (void)rhs; // avoid unused param warning + (void)allocator; // avoid unused param warning } // Move constructor @@ -158,15 +184,26 @@ struct {% if composite_type.deprecated -%} {{composite_type|short_reference_name}}({{composite_type|short_reference_name}}&& rhs, const allocator_type& allocator) {%- if composite_type.fields_except_padding %} :{% endif %} {%- if composite_type.inner_type is UnionType %} - union_value{} // can't make use of the allocator with a union + union_value{std::move( + {%- set ns = namespace(indent = "") %} + {%- for field in composite_type.fields_except_padding %} + {%- if not loop.last %} + {{ns.indent}}rhs.is_{{ field | id }}() ? + {%- set ns.indent = ns.indent ~ "\t" %} + {%- endif %} + {{ns.indent}}VariantType + {{- field | value_initializer(SpecialMethod.MoveConstructorWithAllocator, CompositeSubType.Union) }} + {%- if not loop.last %} :{% endif %} + {%- endfor %} + )} {%- else %} {%- for field in composite_type.fields_except_padding %} {{ field | id }}{{ field | value_initializer(SpecialMethod.MoveConstructorWithAllocator) }}{%if not loop.last %},{%endif %} {%- endfor %} {%- endif %} { - (void)rhs; // avoid unused param warning - (void)allocator; // avoid unused param warning + (void)rhs; // avoid unused param warning + (void)allocator; // avoid unused param warning } // Copy assignment @@ -177,7 +214,7 @@ struct {% if composite_type.deprecated -%} // Destructor ~{{composite_type|short_reference_name}}() = default; -{%- endif %} +{% endif %} {%- for constant in composite_type.constants %} {% if loop.first %} @@ -189,11 +226,6 @@ struct {% if composite_type.deprecated -%} static constexpr {{ constant.data_type | declaration }} {{ constant.name | id }} = {{ constant | constant_value }}; {%- endfor -%} {%- if composite_type.inner_type is UnionType -%} -{%- ifuses "std_variant" -%} -{% include '_fields_as_variant.j2' %} -{%- else -%} -{% include '_fields_as_union.j2' %} -{%- endifuses -%} {%- for field in composite_type.fields_except_padding %} bool is_{{field.name|id}}() const { return VariantType::IndexOf::{{field.name|id}} == union_value.index(); @@ -219,9 +251,10 @@ struct {% if composite_type.deprecated -%} template typename std::add_lvalue_reference<_traits_::TypeOf::{{field.name|id}}>::type set_{{field.name|id}}(Args&&...v){ - return union_value.emplace(v...); + return union_value.emplace(std::forward(v)...); } -{%- endfor %} +{% endfor %} + VariantType union_value; {%- else -%} {% include '_fields.j2' %} {%- endif %} diff --git a/src/nunavut/lang/cpp/templates/_fields_as_union.j2 b/src/nunavut/lang/cpp/templates/_fields_as_union.j2 index 57a6fdf7..e10c0297 100644 --- a/src/nunavut/lang/cpp/templates/_fields_as_union.j2 +++ b/src/nunavut/lang/cpp/templates/_fields_as_union.j2 @@ -18,22 +18,49 @@ public: static const constexpr std::size_t variant_npos = std::numeric_limits::max(); + struct IndexOf final + { + IndexOf() = delete; +{%- for field in composite_type.fields_except_padding %} + static constexpr const std::size_t {{ field.name | id }} = {{ loop.index0 }}U; +{%- endfor %} + }; + + static constexpr const std::size_t MAX_INDEX = {{ composite_type.fields_except_padding | length }}U; + + template + using is_index = std::conditional_t<(I < MAX_INDEX), std::true_type, std::false_type>; + + template + static constexpr const bool is_index_v = is_index::value; + + // Default constructor VariantType() - : tag_(0) - , internal_union_value_() + : tag_{0U} + , internal_union_value_{} { // This is how the C++17 standard library does it; default initialization as the 0th index. emplace<0>(); } + // Initializing constructor + template>> + VariantType(nunavut::support::in_place_index_t, Args&&... args) + : tag_{I} + , internal_union_value_{} + { + do_emplace(std::forward(args)...); + } + + // Copy constructor VariantType(const VariantType& rhs) - : tag_(variant_npos) - , internal_union_value_() + : tag_{variant_npos} + , internal_union_value_{} { {%- for field in composite_type.fields_except_padding %} - {% if not loop.first %}else {% endif %}if(rhs.tag_ == {{ loop.index0 }}) + {% if not loop.first %}else {% endif %}if(rhs.tag_ == IndexOf::{{ field | id }}) { - do_copy<{{ loop.index0 }}>( + do_copy( *reinterpret_cast::type>(&rhs.internal_union_value_.{{ field.name | id }}) ); } @@ -41,14 +68,15 @@ tag_ = rhs.tag_; } + // Move constructor VariantType(VariantType&& rhs) - : tag_(variant_npos) - , internal_union_value_() + : tag_{variant_npos} + , internal_union_value_{} { {%- for field in composite_type.fields_except_padding %} - {% if not loop.first %}else {% endif %}if(rhs.tag_ == {{ loop.index0 }}) + {% if not loop.first %}else {% endif %}if(rhs.tag_ == IndexOf::{{ field | id }}) { - do_emplace<{{ loop.index0 }}>( + do_emplace( std::forward<{{ field.data_type | declaration }}>( *reinterpret_cast::type>(&rhs.internal_union_value_.{{ field.name | id }}) ) @@ -57,13 +85,15 @@ {%- endfor %} tag_ = rhs.tag_; } + + // Copy assignment VariantType& operator=(const VariantType& rhs) { destroy_current(); {%- for field in composite_type.fields_except_padding %} - {% if not loop.first %}else {% endif %}if(rhs.tag_ == {{ loop.index0 }}) + {% if not loop.first %}else {% endif %}if(rhs.tag_ == IndexOf::{{ field | id }}) { - do_copy<{{ loop.index0 }}>( + do_copy( *reinterpret_cast::type>(&rhs.internal_union_value_.{{ field.name | id }}) ); } @@ -72,13 +102,14 @@ return *this; } + // Move assignment VariantType& operator=(VariantType&& rhs) { destroy_current(); {%- for field in composite_type.fields_except_padding %} - {% if not loop.first %}else {% endif %}if(rhs.tag_ == {{ loop.index0 }}) + {% if not loop.first %}else {% endif %}if(rhs.tag_ == IndexOf::{{ field | id }}) { - do_emplace<{{ loop.index0 }}>( + do_emplace( std::forward<{{ field.data_type | declaration }}>( *reinterpret_cast::type>(&rhs.internal_union_value_.{{ field.name | id }}) ) @@ -98,19 +129,10 @@ return tag_; } - struct IndexOf final - { - IndexOf() = delete; -{%- for field in composite_type.fields_except_padding %} - static constexpr const std::size_t {{ field.name | id }} = {{ loop.index0 }}U; -{%- endfor %} - }; - static constexpr const std::size_t MAX_INDEX = {{ composite_type.fields_except_padding | length }}U; - template struct alternative; {% for field in composite_type.fields_except_padding %} - template struct alternative<{{ loop.index0 }}U, Types...> + template struct alternative { using type = {{ field.data_type | declaration }}; static constexpr auto pointer = &VariantType::internal_union_t::{{ field.name | id }}; @@ -120,7 +142,7 @@ template typename VariantType::alternative::type& emplace(Args&&... v) { destroy_current(); - typename alternative::type& result = do_emplace(v...); + typename alternative::type& result = do_emplace(std::forward(v)...); tag_ = I; return result; } @@ -171,7 +193,7 @@ void destroy_current() { {%- for field in composite_type.fields_except_padding if field is not PrimitiveType %} - {% if not loop.first %}else {% endif %}if (tag_ == {{ loop.index0 }}) + {% if not loop.first %}else {% endif %}if (tag_ == IndexOf::{{ field | id }}) { reinterpret_cast<{{ field.data_type | declaration }}*>(std::addressof(internal_union_value_.{{ field.name | id }}))->{{ field.data_type | destructor_name }}(); } @@ -179,5 +201,3 @@ } }; - - VariantType union_value; diff --git a/src/nunavut/lang/cpp/templates/_fields_as_variant.j2 b/src/nunavut/lang/cpp/templates/_fields_as_variant.j2 index 94cbde60..18150aab 100644 --- a/src/nunavut/lang/cpp/templates/_fields_as_variant.j2 +++ b/src/nunavut/lang/cpp/templates/_fields_as_variant.j2 @@ -12,6 +12,12 @@ { public: + using Base = std::variant< +{%- for field in composite_type.fields_except_padding %} + _traits_::TypeOf::{{field.name|id}}{% if not loop.last %},{% endif %} +{%- endfor %} + >; + static const constexpr std::size_t variant_npos = std::variant_npos; struct IndexOf final @@ -21,8 +27,15 @@ static constexpr const std::size_t {{ field.name | id }} = {{ loop.index0 }}U; {%- endfor %} }; + static constexpr const std::size_t MAX_INDEX = {{ composite_type.fields_except_padding | length }}U; + template + using is_index = std::conditional_t<(I < MAX_INDEX), std::true_type, std::false_type>; + + template + static constexpr const bool is_index_v = is_index::value; + template struct alternative; @@ -38,6 +51,42 @@ using type = std::add_const_t::type>; }; + // Default constructor + VariantType() = default; + + // Initializing constructor + template>> + VariantType(nunavut::support::in_place_index_t, Args&&... args) : + Base{std::in_place_index_t{}, std::forward(args)...} + {} + + // Copy constructor + VariantType(const VariantType& rhs) : + Base{rhs} + {} + + // Move constructor + VariantType(VariantType&& rhs) : + Base{std::move(rhs)} + {} + + // Copy assignment + VariantType& operator=(const VariantType& rhs) + { + (void)Base::operator=(rhs); // avoid unused return value warning + return *this; + } + + // Move assignment + VariantType& operator=(VariantType&& rhs) + { + (void)Base::operator=(std::move(rhs)); // avoid unused return value warning + return *this; + } + + // Destructor + ~VariantType() = default; + template static constexpr typename alternative>::type* get_if(std::variant* v) noexcept { @@ -51,5 +100,3 @@ } }; - - VariantType union_value; diff --git a/verification/cpp/suite/test_unionant.cpp b/verification/cpp/suite/test_unionant.cpp index 9710d0fb..c6d988ea 100644 --- a/verification/cpp/suite/test_unionant.cpp +++ b/verification/cpp/suite/test_unionant.cpp @@ -66,22 +66,69 @@ TEST(UnionantTests, get_set_lvalue) } } +/** + * Verify that the variant value can be fetched only as the type being held + */ TEST(UnionantTests, get_if_const_variant) { using ValueType = uavcan::_register::Value_1_0; - const uavcan::_register::Value_1_0 a{}; - const uavcan::primitive::array::Integer32_1_0* p = + const ValueType a{}; + + const uavcan::primitive::Empty_1_0* p_empty = + uavcan::_register::Value_1_0::VariantType::get_if(&a.union_value); + const uavcan::primitive::array::Integer32_1_0* p_int32 = uavcan::_register::Value_1_0::VariantType::get_if(&a.union_value); - ASSERT_EQ(nullptr, p); + + ASSERT_NE(nullptr, p_empty); + ASSERT_EQ(nullptr, p_int32); } +/** + * Verify that the variant value can be fetched only at the alternative index for the type being held + */ TEST(UnionantTests, union_with_same_types) { using ValueType = regulated::basics::UnionWithSameTypes_0_1; - ValueType a{}; - std::array* p = + ValueType a{}; + + regulated::basics::Struct__0_1* p_struct1 = + ValueType::VariantType::get_if(&a.union_value); + regulated::basics::Struct__0_1* p_struct2 = + ValueType::VariantType::get_if(&a.union_value); + std::array* p_delimited_fix_le2 = ValueType::VariantType::get_if(&a.union_value); - ASSERT_EQ(nullptr, p); + + ASSERT_NE(nullptr, p_struct1); + ASSERT_EQ(nullptr, p_struct2); + ASSERT_EQ(nullptr, p_delimited_fix_le2); +} + +/** + * Verify the initializing constructor of the VariantType + */ +TEST(UnionantTests, union_value_init_ctor) +{ + using ValueType = uavcan::_register::Value_1_0; + uavcan::primitive::array::Integer32_1_0 v{{1, 2, 3}}; + const ValueType::VariantType a{ + nunavut::support::in_place_index_t{}, + v + }; + + const uavcan::primitive::Empty_1_0* p_empty = + uavcan::_register::Value_1_0::VariantType::get_if(&a); + const uavcan::primitive::array::Integer32_1_0* p_int32 = + uavcan::_register::Value_1_0::VariantType::get_if(&a); + + ASSERT_EQ(nullptr, p_empty); + ASSERT_NE(nullptr, p_int32); + + if (p_int32 != nullptr) + { + ASSERT_EQ(p_int32->value[0], 1); + ASSERT_EQ(p_int32->value[1], 2); + ASSERT_EQ(p_int32->value[2], 3); + } } /**