From c0355f26fc557b36d5246ab24645ee426cba63a1 Mon Sep 17 00:00:00 2001 From: Lukas Zanner Date: Sun, 25 Feb 2024 00:28:16 +0100 Subject: [PATCH] Add automatic type conversion to row::get() --- include/soci/row.h | 2 +- include/soci/soci-platform.h | 3 + include/soci/type-holder.h | 325 ++++++++++++++++++++++++--- tests/common-tests.h | 116 ++++++++++ tests/mysql/test-mysql.cpp | 3 + tests/postgresql/test-postgresql.cpp | 2 +- 6 files changed, 420 insertions(+), 31 deletions(-) diff --git a/include/soci/row.h b/include/soci/row.h index 3279dae25..fd0e53954 100644 --- a/include/soci/row.h +++ b/include/soci/row.h @@ -63,7 +63,7 @@ class SOCI_DECL row template inline void add_holder(T* t, indicator* ind) { - holders_.push_back(new details::type_holder(t)); + holders_.push_back(details::holder::make_holder(t)); indicators_.push_back(ind); } diff --git a/include/soci/soci-platform.h b/include/soci/soci-platform.h index bb2b8eeff..125f772fe 100644 --- a/include/soci/soci-platform.h +++ b/include/soci/soci-platform.h @@ -37,6 +37,9 @@ //base class must have dll interface #pragma warning(disable:4251 4275) +// As long as we don't require C++17, we must disable the warning +// "conditional expression is constant" +#pragma warning(disable:4127) // Define if you have the vsnprintf variants. #if _MSC_VER < 1500 diff --git a/include/soci/type-holder.h b/include/soci/type-holder.h index f103edb58..2c9e88cb4 100644 --- a/include/soci/type-holder.h +++ b/include/soci/type-holder.h @@ -8,9 +8,15 @@ #ifndef SOCI_TYPE_HOLDER_H_INCLUDED #define SOCI_TYPE_HOLDER_H_INCLUDED +#include "soci/soci-backend.h" #include "soci/soci-platform.h" -// std -#include +#include "soci/soci-types.h" + +#include +#include +#include +#include +#include #include namespace soci @@ -44,49 +50,310 @@ T* checked_ptr_cast(U* ptr) return static_cast(ptr); } -// Base class holder + derived class type_holder for storing type data -// instances in a container of holder objects -template -class type_holder; - -class holder +// Type safe conversion that fails at compilation if instantiated +template +struct soci_cast { -public: - holder() {} - virtual ~holder() {} + static inline T cast(U) + { + throw std::bad_cast(); + } +}; - template - T get() +// Type safe conversion that is a noop +template +struct soci_cast< + T, U, + typename std::enable_if::value>::type> +{ + static inline T cast(U val) { - type_holder* p = checked_ptr_cast >(this); - if (p) - { - return p->template value(); - } - else + return val; + } +}; + +// Type safe conversion that is widening the type +template +struct soci_cast< + T, U, + typename std::enable_if<( + !std::is_same::value && + std::is_integral::value && + std::is_integral::value + )>::type> +{ + static inline T cast(U val) { + intmax_t t_min = static_cast((std::numeric_limits::min)()); + intmax_t u_min = static_cast((std::numeric_limits::min)()); + uintmax_t t_max = static_cast((std::numeric_limits::max)()); + uintmax_t u_max = static_cast((std::numeric_limits::max)()); + + if ((t_min > u_min && val < static_cast(t_min)) || + (t_max < u_max && val > static_cast(t_max))) { throw std::bad_cast(); } - } -private: + return static_cast(val); + } +}; - template - T value(); +union type_holder +{ + std::string* s; + int8_t* i8; + int16_t* i16; + int32_t* i32; + int64_t* i64; + uint8_t* u8; + uint16_t* u16; + uint32_t* u32; + uint64_t* u64; + double* d; + std::tm* t; }; template -class type_holder : public holder +struct type_holder_trait +{ + static_assert(std::is_same::value, "Unmatched raw type"); + // dummy value to satisfy the template engine, never used + static const db_type type = (db_type)0; +}; + +template <> +struct type_holder_trait +{ + static const db_type type = db_string; +}; + +template <> +struct type_holder_trait +{ + static const db_type type = db_int8; +}; + +template <> +struct type_holder_trait +{ + static const db_type type = db_int16; +}; + +template <> +struct type_holder_trait +{ + static const db_type type = db_int32; +}; + +template <> +struct type_holder_trait +{ + static const db_type type = db_int64; +}; + +template <> +struct type_holder_trait +{ + static const db_type type = db_uint8; +}; + +template <> +struct type_holder_trait +{ + static const db_type type = db_uint16; +}; + +template <> +struct type_holder_trait +{ + static const db_type type = db_uint32; +}; + +template <> +struct type_holder_trait +{ + static const db_type type = db_uint64; +}; + +#if defined(SOCI_INT64_IS_LONG) +template <> +struct type_holder_trait : type_holder_trait +{ +}; + +template <> +struct type_holder_trait : type_holder_trait +{ +}; +#elif defined(SOCI_LONG_IS_64_BIT) +template <> +struct type_holder_trait : type_holder_trait +{ +}; + +template <> +struct type_holder_trait : type_holder_trait +{ +}; +#else +template <> +struct type_holder_trait : type_holder_trait +{ +}; + +template <> +struct type_holder_trait : type_holder_trait +{ +}; +#endif + +template <> +struct type_holder_trait +{ + static const db_type type = db_double; +}; + +template <> +struct type_holder_trait +{ + static const db_type type = db_date; +}; + +// Base class for storing type data instances in a container of holder objects +class holder { public: - type_holder(T * t) : t_(t) {} - ~type_holder() override { delete t_; } + template + static holder* make_holder(T* val) + { + return new holder(type_holder_trait::type, val); + } - template - TypeValue value() const { return *t_; } + ~holder() + { + switch (dt_) + { + case db_double: + delete val_.d; + break; + case db_int8: + delete val_.i8; + break; + case db_int16: + delete val_.i16; + break; + case db_int32: + delete val_.i32; + break; + case db_int64: + delete val_.i64; + break; + case db_uint8: + delete val_.u8; + break; + case db_uint16: + delete val_.u16; + break; + case db_uint32: + delete val_.u32; + break; + case db_uint64: + delete val_.u64; + break; + case db_date: + delete val_.t; + break; + case db_blob: + case db_xml: + case db_string: + delete val_.s; + break; + default: + break; + } + } + + template + T get() + { + switch (dt_) + { + case db_int8: + return soci_cast::cast(*val_.i8); + case db_int16: + return soci_cast::cast(*val_.i16); + case db_int32: + return soci_cast::cast(*val_.i32); + case db_int64: + return soci_cast::cast(*val_.i64); + case db_uint8: + return soci_cast::cast(*val_.u8); + case db_uint16: + return soci_cast::cast(*val_.u16); + case db_uint32: + return soci_cast::cast(*val_.u32); + case db_uint64: + return soci_cast::cast(*val_.u64); + case db_double: + return soci_cast::cast(*val_.d); + case db_date: + return soci_cast::cast(*val_.t); + case db_blob: + case db_xml: + case db_string: + return soci_cast::cast(*val_.s); + default: + throw std::bad_cast(); + } + } private: - T * t_; + holder(db_type dt, void* val) : dt_(dt) + { + switch (dt_) + { + case db_double: + val_.d = static_cast(val); + break; + case db_int8: + val_.i8 = static_cast(val); + break; + case db_int16: + val_.i16 = static_cast(val); + break; + case db_int32: + val_.i32 = static_cast(val); + break; + case db_int64: + val_.i64 = static_cast(val); + break; + case db_uint8: + val_.u8 = static_cast(val); + break; + case db_uint16: + val_.u16 = static_cast(val); + break; + case db_uint32: + val_.u32 = static_cast(val); + break; + case db_uint64: + val_.u64 = static_cast(val); + break; + case db_date: + val_.t = static_cast(val); + break; + case db_blob: + case db_xml: + case db_string: + val_.s = static_cast(val); + break; + default: + break; + } + } + + const db_type dt_; + type_holder val_; }; } // namespace details diff --git a/tests/common-tests.h b/tests/common-tests.h index 9fd7b755c..8717c863c 100644 --- a/tests/common-tests.h +++ b/tests/common-tests.h @@ -3317,6 +3317,122 @@ TEST_CASE_METHOD(common_tests, "Dynamic binding with type conversions", "[core][ } } +// Dynamic bindings with type casts +TEST_CASE_METHOD(common_tests, "Dynamic row binding 4", "[core][dynamic]") +{ + soci::session sql(backEndFactory_, connectString_); + + SECTION("simple type cast") + { + auto_table_creator tableCreator(tc_.table_creator_1(sql)); + + sql << "insert into soci_test(id, d, str, tm)" + << " values(10, 20.0, 'foobar'," + << tc_.to_date_time("2005-12-19 22:14:17") + << ")"; + + { + row r; + sql << "select id from soci_test", into(r); + + static_assert(std::numeric_limits::digits10 == 9, ""); + CHECK(r.size() == 1); + CHECK(r.get(0) == 10); + CHECK(r.get(0) == 10); + CHECK(r.get(0) == 10); + CHECK(r.get(0) == 10); + CHECK(r.get(0) == 10); + CHECK(r.get(0) == 10); + CHECK(r.get(0) == 10); + CHECK(r.get(0) == 10); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + } + { + row r; + sql << "select d from soci_test", into(r); + + CHECK(r.size() == 1); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + ASSERT_EQUAL_APPROX(r.get(0), 20.0); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + } + { + row r; + sql << "select str from soci_test", into(r); + + CHECK(r.size() == 1); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK(r.get(0) == "foobar"); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + } + { + row r; + sql << "select tm from soci_test", into(r); + + CHECK(r.size() == 1); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK(r.get(0).tm_year == 105); + CHECK(r.get(0).tm_mon == 11); + CHECK(r.get(0).tm_mday == 19); + CHECK(r.get(0).tm_hour == 22); + CHECK(r.get(0).tm_min == 14); + CHECK(r.get(0).tm_sec == 17); + } + } + SECTION("overflowing type cast") + { + auto_table_creator tableCreator(tc_.table_creator_1(sql)); + + sql << "insert into soci_test(id)" + << " values(" + << (std::numeric_limits::max)() + << ")"; + + row r; + sql << "select id from soci_test", into(r); + + intmax_t v = (intmax_t)(std::numeric_limits::max)(); + uintmax_t uv = (uintmax_t)(std::numeric_limits::max)(); + + CHECK(r.size() == 1); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK(r.get(0) == v); + CHECK(r.get(0) == v); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK_THROWS_AS(r.get(0), std::bad_cast); + CHECK(r.get(0) == uv); + CHECK(r.get(0) == uv); + } +} + TEST_CASE_METHOD(common_tests, "Prepared insert with ORM", "[core][orm]") { soci::session sql(backEndFactory_, connectString_); diff --git a/tests/mysql/test-mysql.cpp b/tests/mysql/test-mysql.cpp index 7e933432e..627b50e84 100644 --- a/tests/mysql/test-mysql.cpp +++ b/tests/mysql/test-mysql.cpp @@ -702,6 +702,7 @@ TEST_CASE("MySQL tinyint", "[mysql][int][tinyint]") REQUIRE(r.size() == 1); CHECK(r.get_properties("val").get_data_type() == dt_long_long); CHECK(r.get_properties("val").get_db_type() == db_uint32); + CHECK(r.get("val") == 0xffffff00); CHECK(r.get("val") == 0xffffff00); CHECK(r.get("val") == 0xffffff00); } @@ -714,6 +715,7 @@ TEST_CASE("MySQL tinyint", "[mysql][int][tinyint]") REQUIRE(r.size() == 1); CHECK(r.get_properties("val").get_data_type() == dt_integer); CHECK(r.get_properties("val").get_db_type() == db_int8); + CHECK(r.get("val") == -123); CHECK(r.get("val") == -123); } { @@ -725,6 +727,7 @@ TEST_CASE("MySQL tinyint", "[mysql][int][tinyint]") REQUIRE(r.size() == 1); CHECK(r.get_properties("val").get_data_type() == dt_integer); CHECK(r.get_properties("val").get_db_type() == db_uint8); + CHECK(r.get("val") == 123); CHECK(r.get("val") == 123); } { diff --git a/tests/postgresql/test-postgresql.cpp b/tests/postgresql/test-postgresql.cpp index a69459c49..8a1178634 100644 --- a/tests/postgresql/test-postgresql.cpp +++ b/tests/postgresql/test-postgresql.cpp @@ -1214,7 +1214,7 @@ struct test_enum_with_explicit_custom_type_int_rowset : table_creator_base try { - sql << "CREATE TABLE soci_test( Type integer)"; + sql << "CREATE TABLE soci_test( Type smallint)"; ; } catch (...)