From ae07cd546ee3f21b7dc1be6ff36be5593ccb4a67 Mon Sep 17 00:00:00 2001 From: Lukas Zanner Date: Sun, 25 Feb 2024 00:28:16 +0100 Subject: [PATCH] Add automatic type conversion to row::get() Convert to the requested type if lossless conversion is possible to make the behaviour more compatible with the previous SOCI versions and more useful. See #1127. --- include/soci/row.h | 8 +- include/soci/type-holder.h | 412 +++++++++++++++++++++++++-- src/core/row.cpp | 3 +- tests/common-tests.h | 115 ++++++++ tests/mysql/test-mysql.cpp | 3 + tests/postgresql/test-postgresql.cpp | 2 +- 6 files changed, 507 insertions(+), 36 deletions(-) diff --git a/include/soci/row.h b/include/soci/row.h index e76aea403..07f7c3432 100644 --- a/include/soci/row.h +++ b/include/soci/row.h @@ -65,7 +65,7 @@ class SOCI_DECL row template inline void add_holder(T* t, indicator* ind) { - holders_.push_back(new details::type_holder(t)); + holders_.push_back(details::holder::make_holder(t)); indicators_.push_back(ind); } @@ -78,7 +78,8 @@ class SOCI_DECL row typedef typename type_conversion::base_type base_type; static_assert(details::can_use_from_base>(), "Can't use row::get() with this type (not convertible/copy-assignable from base_type) - did you mean to use move_as?"); - base_type const& baseVal = holders_.at(pos)->get(); + base_type const& baseVal = + holders_.at(pos)->get(details::value_cast_tag{}); T ret; type_conversion::from_base(baseVal, *indicators_.at(pos), ret); @@ -91,7 +92,8 @@ class SOCI_DECL row typedef typename type_conversion::base_type base_type; static_assert(details::can_use_move_from_base(), "row::move_as() can only be called with types that can be instantiated from a base type rvalue reference"); - base_type & baseVal = holders_.at(pos)->get(); + base_type & baseVal = + holders_.at(pos)->get(details::value_reference_tag{}); T ret; type_conversion::move_from_base(baseVal, *indicators_.at(pos), ret); diff --git a/include/soci/type-holder.h b/include/soci/type-holder.h index a9e189d23..73badaa14 100644 --- a/include/soci/type-holder.h +++ b/include/soci/type-holder.h @@ -8,9 +8,16 @@ #ifndef SOCI_TYPE_HOLDER_H_INCLUDED #define SOCI_TYPE_HOLDER_H_INCLUDED -#include "soci/soci-platform.h" -// std -#include +#include "soci/blob.h" +#include "soci/error.h" +#include "soci/soci-backend.h" +#include "soci/soci-types.h" + +#include +#include +#include +#include +#include #include namespace soci @@ -44,52 +51,395 @@ T* checked_ptr_cast(U* ptr) return static_cast(ptr); } -// Base class holder + derived class type_holder for storing type data -// instances in a container of holder objects -template -class type_holder; +template +struct soci_return_same +{ + static inline T& value(U&) + { + throw std::bad_cast(); + } +}; -class holder +template +struct soci_return_same< + T, U, + typename std::enable_if::value>::type> { -public: - holder() {} - virtual ~holder() {} + static inline T& value(U& val) + { + return val; + } +}; - template - T &get() +// Type safe conversion that throws if the types are mismatched +template +struct soci_cast +{ + static inline T cast(U) { - type_holder* p = checked_ptr_cast >(this); - if (p) - { - return p->template value(); - } - else + throw std::bad_cast(); + } +}; + +// Type safe conversion that is a noop +template +struct soci_cast< + T, U, + typename std::enable_if::value>::type> +{ + static inline T cast(U val) + { + return val; + } +}; + +// Type safe conversion that is widening the type +template +struct soci_cast< + T, U, + typename std::enable_if<( + !std::is_same::value && + std::is_integral::value && + std::is_integral::value + )>::type> +{ + static inline T cast(U val) { + intmax_t t_min = static_cast((std::numeric_limits::min)()); + intmax_t u_min = static_cast((std::numeric_limits::min)()); + uintmax_t t_max = static_cast((std::numeric_limits::max)()); + uintmax_t u_max = static_cast((std::numeric_limits::max)()); + +#ifdef _MSC_VER +// As long as we don't require C++17, we must disable the warning +// "conditional expression is constant" as it can give false positives here. +#pragma warning(push) +#pragma warning(disable:4127) +#endif + if ((t_min > u_min && val < static_cast(t_min)) || + (t_max < u_max && val > static_cast(t_max))) { throw std::bad_cast(); } - } +#ifdef _MSC_VER +#pragma warning(pop) +#endif -private: + return static_cast(val); + } +}; - template - T value(); +union type_holder +{ + std::string* s; + int8_t* i8; + int16_t* i16; + int32_t* i32; + int64_t* i64; + uint8_t* u8; + uint16_t* u16; + uint32_t* u32; + uint64_t* u64; + double* d; + std::tm* t; + blob* b; }; template -class type_holder : public holder +struct type_holder_trait; + +template <> +struct type_holder_trait +{ + static const db_type type = db_string; +}; + +template <> +struct type_holder_trait +{ + static const db_type type = db_int8; +}; + +template <> +struct type_holder_trait +{ + static const db_type type = db_int16; +}; + +template <> +struct type_holder_trait +{ + static const db_type type = db_int32; +}; + +template <> +struct type_holder_trait +{ + static const db_type type = db_int64; +}; + +template <> +struct type_holder_trait +{ + static const db_type type = db_uint8; +}; + +template <> +struct type_holder_trait +{ + static const db_type type = db_uint16; +}; + +template <> +struct type_holder_trait +{ + static const db_type type = db_uint32; +}; + +template <> +struct type_holder_trait +{ + static const db_type type = db_uint64; +}; + +#if defined(SOCI_INT64_IS_LONG) +template <> +struct type_holder_trait : type_holder_trait +{ +}; + +template <> +struct type_holder_trait : type_holder_trait +{ +}; +#elif defined(SOCI_LONG_IS_64_BIT) +template <> +struct type_holder_trait : type_holder_trait +{ +}; + +template <> +struct type_holder_trait : type_holder_trait +{ +}; +#else +template <> +struct type_holder_trait : type_holder_trait +{ +}; + +template <> +struct type_holder_trait : type_holder_trait +{ +}; +#endif + +template <> +struct type_holder_trait +{ + static const db_type type = db_double; +}; + +template <> +struct type_holder_trait +{ + static const db_type type = db_date; +}; + +template <> +struct type_holder_trait +{ + static const db_type type = db_blob; +}; + +struct value_cast_tag{}; +struct value_reference_tag{}; + +// Class for storing type data instances in a container of holder objects +class holder { public: - type_holder(T * t) : t_(t) {} - ~type_holder() override { delete t_; } + template + static holder* make_holder(T* val) + { + return new holder(type_holder_trait::type, val); + } - template - const TypeValue &value() const { return *t_; } + ~holder() + { + switch (dt_) + { + case db_double: + delete val_.d; + break; + case db_int8: + delete val_.i8; + break; + case db_int16: + delete val_.i16; + break; + case db_int32: + delete val_.i32; + break; + case db_int64: + delete val_.i64; + break; + case db_uint8: + delete val_.u8; + break; + case db_uint16: + delete val_.u16; + break; + case db_uint32: + delete val_.u32; + break; + case db_uint64: + delete val_.u64; + break; + case db_date: + delete val_.t; + break; + case db_blob: + delete val_.b; + break; + case db_xml: + case db_string: + delete val_.s; + break; + } + } - template - TypeValue &value() { return *t_; } +#ifdef _MSC_VER +// MSVC complains about "unreachable code" even though all +// code here can be reached. +#pragma warning(push) +#pragma warning(disable:4702) +#endif + template + T get(value_cast_tag) + { + switch (dt_) + { + case db_int8: + return soci_cast::cast(*val_.i8); + case db_int16: + return soci_cast::cast(*val_.i16); + case db_int32: + return soci_cast::cast(*val_.i32); + case db_int64: + return soci_cast::cast(*val_.i64); + case db_uint8: + return soci_cast::cast(*val_.u8); + case db_uint16: + return soci_cast::cast(*val_.u16); + case db_uint32: + return soci_cast::cast(*val_.u32); + case db_uint64: + return soci_cast::cast(*val_.u64); + case db_double: + return soci_cast::cast(*val_.d); + case db_date: + return soci_cast::cast(*val_.t); + case db_blob: + // blob is not copyable + break; + case db_xml: + case db_string: + return soci_cast::cast(*val_.s); + } + + throw std::bad_cast(); + } + + template + T& get(value_reference_tag) + { + switch (dt_) + { + case db_int8: + return soci_return_same::value(*val_.i8); + case db_int16: + return soci_return_same::value(*val_.i16); + case db_int32: + return soci_return_same::value(*val_.i32); + case db_int64: + return soci_return_same::value(*val_.i64); + case db_uint8: + return soci_return_same::value(*val_.u8); + case db_uint16: + return soci_return_same::value(*val_.u16); + case db_uint32: + return soci_return_same::value(*val_.u32); + case db_uint64: + return soci_return_same::value(*val_.u64); + case db_double: + return soci_return_same::value(*val_.d); + case db_date: + return soci_return_same::value(*val_.t); + case db_blob: + return soci_return_same::value(*val_.b); + case db_xml: + case db_string: + return soci_return_same::value(*val_.s); + } + + throw std::bad_cast(); + } +#ifdef _MSC_VER +#pragma warning(pop) +#endif private: - T * t_; + holder(db_type dt, void* val) : dt_(dt) + { + switch (dt_) + { + case db_double: + val_.d = static_cast(val); + return; + case db_int8: + val_.i8 = static_cast(val); + return; + case db_int16: + val_.i16 = static_cast(val); + return; + case db_int32: + val_.i32 = static_cast(val); + return; + case db_int64: + val_.i64 = static_cast(val); + return; + case db_uint8: + val_.u8 = static_cast(val); + return; + case db_uint16: + val_.u16 = static_cast(val); + return; + case db_uint32: + val_.u32 = static_cast(val); + return; + case db_uint64: + val_.u64 = static_cast(val); + return; + case db_date: + val_.t = static_cast(val); + return; + case db_blob: + val_.b = static_cast(val); + return; + case db_xml: + case db_string: + val_.s = static_cast(val); + return; + } + + // This should be unreachable + std::ostringstream ss; + ss << "Created holder with unsupported type " << std::to_string(dt); + throw soci_error(ss.str()); + } + + const db_type dt_; + type_holder val_; }; } // namespace details diff --git a/src/core/row.cpp b/src/core/row.cpp index da145deb7..af4bdc745 100644 --- a/src/core/row.cpp +++ b/src/core/row.cpp @@ -7,6 +7,7 @@ #define SOCI_SOURCE #include "soci/row.h" +#include "soci/type-holder.h" #include #include @@ -114,7 +115,7 @@ template <> blob row::move_as(std::size_t pos) const { typedef typename type_conversion::base_type base_type; - base_type & baseVal = holders_.at(pos)->get(); + base_type & baseVal = holders_.at(pos)->get(value_reference_tag{}); blob ret; type_conversion::move_from_base(baseVal, *indicators_.at(pos), ret); diff --git a/tests/common-tests.h b/tests/common-tests.h index c3d222c6f..d2b8fb42c 100644 --- a/tests/common-tests.h +++ b/tests/common-tests.h @@ -3323,6 +3323,121 @@ TEST_CASE_METHOD(common_tests, "Dynamic binding with type conversions", "[core][ } } +// Dynamic bindings with type casts +TEST_CASE_METHOD(common_tests, "Dynamic row binding 4", "[core][dynamic]") +{ + soci::session sql(backEndFactory_, connectString_); + + SECTION("simple type cast") + { + auto_table_creator tableCreator(tc_.table_creator_1(sql)); + + sql << "insert into soci_test(id, d, str, tm)" + << " values(10, 20.0, 'foobar'," + << tc_.to_date_time("2005-12-19 22:14:17") + << ")"; + + { + row r; + sql << "select id from soci_test", into(r); + + CHECK(r.size() == 1); + CHECK(r.get(0) == 10); + CHECK(r.get(0) == 10); + CHECK(r.get(0) == 10); + CHECK(r.get(0) == 10); + CHECK(r.get(0) == 10); + CHECK(r.get(0) == 10); + CHECK(r.get(0) == 10); + CHECK(r.get(0) == 10); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + } + { + row r; + sql << "select d from soci_test", into(r); + + CHECK(r.size() == 1); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + ASSERT_EQUAL_APPROX(r.get(0), 20.0); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + } + { + row r; + sql << "select str from soci_test", into(r); + + CHECK(r.size() == 1); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK(r.get(0) == "foobar"); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + } + { + row r; + sql << "select tm from soci_test", into(r); + + CHECK(r.size() == 1); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK(r.get(0).tm_year == 105); + CHECK(r.get(0).tm_mon == 11); + CHECK(r.get(0).tm_mday == 19); + CHECK(r.get(0).tm_hour == 22); + CHECK(r.get(0).tm_min == 14); + CHECK(r.get(0).tm_sec == 17); + } + } + SECTION("overflowing type cast") + { + auto_table_creator tableCreator(tc_.table_creator_1(sql)); + + sql << "insert into soci_test(id)" + << " values(" + << (std::numeric_limits::max)() + << ")"; + + row r; + sql << "select id from soci_test", into(r); + + intmax_t v = (intmax_t)(std::numeric_limits::max)(); + uintmax_t uv = (uintmax_t)(std::numeric_limits::max)(); + + CHECK(r.size() == 1); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK(r.get(0) == v); + CHECK(r.get(0) == v); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK(r.get(0) == uv); + CHECK(r.get(0) == uv); + } +} + TEST_CASE_METHOD(common_tests, "Prepared insert with ORM", "[core][orm]") { soci::session sql(backEndFactory_, connectString_); diff --git a/tests/mysql/test-mysql.cpp b/tests/mysql/test-mysql.cpp index 7d548d217..cc5c50c5f 100644 --- a/tests/mysql/test-mysql.cpp +++ b/tests/mysql/test-mysql.cpp @@ -637,6 +637,7 @@ TEST_CASE("MySQL tinyint", "[mysql][int][tinyint]") REQUIRE(r.size() == 1); CHECK(r.get_properties("val").get_data_type() == dt_long_long); CHECK(r.get_properties("val").get_db_type() == db_uint32); + CHECK(r.get("val") == 0xffffff00); CHECK(r.get("val") == 0xffffff00); CHECK(r.get("val") == 0xffffff00); } @@ -649,6 +650,7 @@ TEST_CASE("MySQL tinyint", "[mysql][int][tinyint]") REQUIRE(r.size() == 1); CHECK(r.get_properties("val").get_data_type() == dt_integer); CHECK(r.get_properties("val").get_db_type() == db_int8); + CHECK(r.get("val") == -123); CHECK(r.get("val") == -123); } { @@ -660,6 +662,7 @@ TEST_CASE("MySQL tinyint", "[mysql][int][tinyint]") REQUIRE(r.size() == 1); CHECK(r.get_properties("val").get_data_type() == dt_integer); CHECK(r.get_properties("val").get_db_type() == db_uint8); + CHECK(r.get("val") == 123); CHECK(r.get("val") == 123); } { diff --git a/tests/postgresql/test-postgresql.cpp b/tests/postgresql/test-postgresql.cpp index 651685866..16ce3fc8d 100644 --- a/tests/postgresql/test-postgresql.cpp +++ b/tests/postgresql/test-postgresql.cpp @@ -1161,7 +1161,7 @@ struct test_enum_with_explicit_custom_type_int_rowset : table_creator_base try { - sql << "CREATE TABLE soci_test( Type integer)"; + sql << "CREATE TABLE soci_test( Type smallint)"; ; } catch (...)