Skip to content

Commit

Permalink
Revert "Refactor thrust::complex as a struct derived from `cuda::st…
Browse files Browse the repository at this point in the history
…d::complex` (#454)" (#1286)

Co-authored-by: Jake Hemstad <jhemstad@nvidia.com>
  • Loading branch information
miscco and jrhemstad authored Jan 17, 2024
1 parent 2d9b6ff commit 0df3163
Show file tree
Hide file tree
Showing 26 changed files with 5,668 additions and 539 deletions.
64 changes: 2 additions & 62 deletions thrust/testing/complex.cu
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,6 @@ struct TestComplexSizeAndAlignment
};
SimpleUnitTest<TestComplexSizeAndAlignment, FloatingPointTypes> TestComplexSizeAndAlignmentInstance;

template <typename T>
struct TestComplexTypeCheck
{
void operator()()
{
THRUST_STATIC_ASSERT(thrust::is_complex<thrust::complex<T>>::value);
THRUST_STATIC_ASSERT(thrust::is_complex<std::complex<T>>::value);
THRUST_STATIC_ASSERT(thrust::is_complex<cuda::std::complex<T>>::value);
}
};
SimpleUnitTest<TestComplexTypeCheck, FloatingPointTypes> TestComplexTypeCheckInstance;

template <typename T>
struct TestComplexConstructionAndAssignment
{
Expand Down Expand Up @@ -449,18 +437,17 @@ struct TestComplexBasicArithmetic
// Test the basic arithmetic functions against std

ASSERT_ALMOST_EQUAL(thrust::abs(a), std::abs(b));

ASSERT_ALMOST_EQUAL(thrust::arg(a), std::arg(b));

ASSERT_ALMOST_EQUAL(thrust::norm(a), std::norm(b));

ASSERT_EQUAL(thrust::conj(a), std::conj(b));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::conj(a))>::value, "");

ASSERT_ALMOST_EQUAL(thrust::polar(data[0], data[1]), std::polar(data[0], data[1]));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::polar(data[0], data[1]))>::value, "");

// random_samples does not seem to produce infinities so proj(z) == z
ASSERT_EQUAL(thrust::proj(a), a);
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::proj(a))>::value, "");
}
};
SimpleUnitTest<TestComplexBasicArithmetic, FloatingPointTypes> TestComplexBasicArithmeticInstance;
Expand Down Expand Up @@ -557,9 +544,6 @@ struct TestComplexExponentialFunctions
ASSERT_ALMOST_EQUAL(thrust::exp(a), std::exp(b));
ASSERT_ALMOST_EQUAL(thrust::log(a), std::log(b));
ASSERT_ALMOST_EQUAL(thrust::log10(a), std::log10(b));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::exp(a))>::value, "");
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::log(a))>::value, "");
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::log10(a))>::value, "");
}
};
SimpleUnitTest<TestComplexExponentialFunctions, FloatingPointTypes>
Expand All @@ -579,24 +563,16 @@ struct TestComplexPowerFunctions
const std::complex<T> b_std(b_thrust);

ASSERT_ALMOST_EQUAL(thrust::pow(a_thrust, b_thrust), std::pow(a_std, b_std));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::pow(a_thrust, b_thrust))>::value, "");
ASSERT_ALMOST_EQUAL(thrust::pow(a_thrust, b_thrust.real()), std::pow(a_std, b_std.real()));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::pow(a_thrust, b_thrust.real()))>::value, "");
ASSERT_ALMOST_EQUAL(thrust::pow(a_thrust.real(), b_thrust), std::pow(a_std.real(), b_std));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::pow(a_thrust.real(), b_thrust))>::value, "");

ASSERT_ALMOST_EQUAL(thrust::pow(a_thrust, 4), std::pow(a_std, 4));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::pow(a_thrust, 4))>::value, "");

ASSERT_ALMOST_EQUAL(thrust::sqrt(a_thrust), std::sqrt(a_std));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::sqrt(a_thrust))>::value, "");
}

// Test power functions with promoted types.
{
using T0 = T;
using T1 = other_floating_point_type_t<T0>;
using promoted = typename thrust::detail::promoted_numerical_type<T0, T1>::type;

thrust::host_vector<T0> data = unittest::random_samples<T0>(4);

Expand All @@ -606,17 +582,11 @@ struct TestComplexPowerFunctions
const std::complex<T0> b_std(data[2], data[3]);

ASSERT_ALMOST_EQUAL(thrust::pow(a_thrust, b_thrust), std::pow(a_std, b_std));
static_assert(cuda::std::is_same<thrust::complex<promoted>, decltype(thrust::pow(a_thrust, b_thrust))>::value, "");
ASSERT_ALMOST_EQUAL(thrust::pow(b_thrust, a_thrust), std::pow(b_std, a_std));
static_assert(cuda::std::is_same<thrust::complex<promoted>, decltype(thrust::pow(b_thrust, a_thrust))>::value, "");
ASSERT_ALMOST_EQUAL(thrust::pow(a_thrust, b_thrust.real()), std::pow(a_std, b_std.real()));
static_assert(cuda::std::is_same<thrust::complex<promoted>, decltype(thrust::pow(a_thrust, b_thrust.real()))>::value, "");
ASSERT_ALMOST_EQUAL(thrust::pow(b_thrust, a_thrust.real()), std::pow(b_std, a_std.real()));
static_assert(cuda::std::is_same<thrust::complex<promoted>, decltype(thrust::pow(b_thrust, a_thrust.real()))>::value, "");
ASSERT_ALMOST_EQUAL(thrust::pow(a_thrust.real(), b_thrust), std::pow(a_std.real(), b_std));
static_assert(cuda::std::is_same<thrust::complex<promoted>, decltype(thrust::pow(a_thrust.real(), b_thrust))>::value, "");
ASSERT_ALMOST_EQUAL(thrust::pow(b_thrust.real(), a_thrust), std::pow(b_std.real(), a_std));
static_assert(cuda::std::is_same<thrust::complex<promoted>, decltype(thrust::pow(b_thrust.real(), a_thrust))>::value, "");
}
}
};
Expand All @@ -635,32 +605,20 @@ struct TestComplexTrigonometricFunctions
ASSERT_ALMOST_EQUAL(thrust::cos(a), std::cos(c));
ASSERT_ALMOST_EQUAL(thrust::sin(a), std::sin(c));
ASSERT_ALMOST_EQUAL(thrust::tan(a), std::tan(c));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::cos(a))>::value, "");
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::sin(a))>::value, "");
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::tan(a))>::value, "");

ASSERT_ALMOST_EQUAL(thrust::cosh(a), std::cosh(c));
ASSERT_ALMOST_EQUAL(thrust::sinh(a), std::sinh(c));
ASSERT_ALMOST_EQUAL(thrust::tanh(a), std::tanh(c));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::cosh(a))>::value, "");
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::sinh(a))>::value, "");
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::tanh(a))>::value, "");

#if THRUST_CPP_DIALECT >= 2011

ASSERT_ALMOST_EQUAL(thrust::acos(a), std::acos(c));
ASSERT_ALMOST_EQUAL(thrust::asin(a), std::asin(c));
ASSERT_ALMOST_EQUAL(thrust::atan(a), std::atan(c));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::acos(a))>::value, "");
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::asin(a))>::value, "");
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::atan(a))>::value, "");

ASSERT_ALMOST_EQUAL(thrust::acosh(a), std::acosh(c));
ASSERT_ALMOST_EQUAL(thrust::asinh(a), std::asinh(c));
ASSERT_ALMOST_EQUAL(thrust::atanh(a), std::atanh(c));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::acosh(a))>::value, "");
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::asinh(a))>::value, "");
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::atanh(a))>::value, "");

#endif
}
Expand Down Expand Up @@ -709,21 +667,3 @@ struct TestComplexStdComplexDeviceInterop
SimpleUnitTest<TestComplexStdComplexDeviceInterop, FloatingPointTypes>
TestComplexStdComplexDeviceInteropInstance;
#endif

template <typename T>
struct TestComplexExplicitConstruction
{
struct user_complex {
__host__ __device__ user_complex(T, T) {}
__host__ __device__ user_complex(const thrust::complex<T>&) {}
};

void operator()()
{
const thrust::complex<T> input(42.0, 1337.0);
const user_complex result = thrust::exp(input);
(void)result;
}
};
SimpleUnitTest<TestComplexExplicitConstruction, FloatingPointTypes>
TestComplexExplicitConstructionInstance;
29 changes: 23 additions & 6 deletions thrust/testing/unittest/assertions.h
Original file line number Diff line number Diff line change
Expand Up @@ -304,14 +304,31 @@ bool almost_equal(double a, double b, double a_tol, double r_tol)
return true;
}

namespace
{ // anonymous namespace

template <typename>
struct is_complex : public THRUST_NS_QUALIFIER::false_type
{};

template <typename T>
struct is_complex<THRUST_NS_QUALIFIER::complex<T>> : public THRUST_NS_QUALIFIER::true_type
{};

template <typename T>
struct is_complex<std::complex<T>> : public THRUST_NS_QUALIFIER::true_type
{};

} // namespace

template <typename T1, typename T2>
typename THRUST_NS_QUALIFIER::detail::enable_if<THRUST_NS_QUALIFIER::is_complex<T1>::value &&
THRUST_NS_QUALIFIER::is_complex<T2>::value,
bool>::type
almost_equal(const T1 &a, const T2 &b, double a_tol, double r_tol)
inline
typename THRUST_NS_QUALIFIER::detail::enable_if<is_complex<T1>::value && is_complex<T2>::value,
bool>::type
almost_equal(const T1 &a, const T2 &b, double a_tol, double r_tol)
{
return almost_equal(a.real(), b.real(), a_tol, r_tol) &&
almost_equal(a.imag(), b.imag(), a_tol, r_tol);
return almost_equal(a.real(), b.real(), a_tol, r_tol) &&
almost_equal(a.imag(), b.imag(), a_tol, r_tol);
}

template <typename T1, typename T2>
Expand Down
Loading

0 comments on commit 0df3163

Please sign in to comment.