From 5fe696b1c209321238cbb8cfaa9c402c007cbc22 Mon Sep 17 00:00:00 2001 From: Malcolm Roberts Date: Tue, 23 Jul 2019 16:19:56 -0600 Subject: [PATCH] Fix bug in Bluestein Algorithm and add tests (#220) * Add prime-size tests for 1D c2c transforms * Fix Bluestein algorithm bug --- clients/tests/accuracy_test_1D.cpp | 19 +++++++++++++++---- library/src/plan.cpp | 18 ++++++++++++++---- 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/clients/tests/accuracy_test_1D.cpp b/clients/tests/accuracy_test_1D.cpp index 443cffd7..90b629c9 100644 --- a/clients/tests/accuracy_test_1D.cpp +++ b/clients/tests/accuracy_test_1D.cpp @@ -66,6 +66,8 @@ class accuracy_test_complex_pow2_double : public ::testing::Test #define MIX_RANGE \ 6, 10, 12, 15, 20, 30, 120, 150, 225, 240, 300, 486, 600, 900, 1250, 1500, 1875, 2160, 2187, \ 2250, 2500, 3000, 4000, 12000, 24000, 72000 +#define PRIME_RANGE \ + 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97 #define LARGE_RANGE \ 4096, 4050, 4000, 3888, 3840, 3750, 3645, 3600, 3456, 3375, 3240, 3200, 3125, 3072, 3000, \ @@ -77,10 +79,11 @@ class accuracy_test_complex_pow2_double : public ::testing::Test 160, 150, 144, 135, 128, 125, 120, 108, 100, 96, 90, 81, 80, 75, 72, 64, 60, 54, 50, 48, \ 45, 40, 36, 32, 30, 27, 25,\ 24, 20, 18, 16, 15, 12, 10, 9, 8, 6, 5, 4, 3, 2, 1 -static std::vector pow2_range = {POW2_RANGE}; -static std::vector pow3_range = {POW3_RANGE}; -static std::vector pow5_range = {POW5_RANGE}; -static std::vector mix_range = {MIX_RANGE}; +static std::vector pow2_range = {POW2_RANGE}; +static std::vector pow3_range = {POW3_RANGE}; +static std::vector pow5_range = {POW5_RANGE}; +static std::vector mix_range = {MIX_RANGE}; +static std::vector prime_range = {PRIME_RANGE}; static size_t batch_range[] = {1}; @@ -415,6 +418,14 @@ INSTANTIATE_TEST_CASE_P(rocfft_pow_random_1D, ValuesIn(transform_range), ValuesIn(stride_range))); +INSTANTIATE_TEST_CASE_P(rocfft_prime_1D, + accuracy_test_complex, 
+ Combine(ValuesIn(prime_range), + ValuesIn(batch_range), + ValuesIn(placeness_range), + ValuesIn(transform_range), + ValuesIn(stride_range))); + // ***************************************************** // REAL HERMITIAN // ***************************************************** diff --git a/library/src/plan.cpp b/library/src/plan.cpp index 8386186b..e61430b1 100644 --- a/library/src/plan.cpp +++ b/library/src/plan.cpp @@ -1672,11 +1672,11 @@ void TreeNode::TraverseTreeAssignBuffersLogicA(OperatingBuffer& flipIn, assert(childNodes.size() == 7); assert(childNodes[0]->scheme == CS_KERNEL_CHIRP); - childNodes[0]->obIn = obIn; + childNodes[0]->obIn = OB_TEMP_BLUESTEIN; childNodes[0]->obOut = OB_TEMP_BLUESTEIN; assert(childNodes[1]->scheme == CS_KERNEL_PAD_MUL); - childNodes[1]->obIn = OB_TEMP_BLUESTEIN; + childNodes[1]->obIn = obIn; childNodes[1]->obOut = OB_TEMP_BLUESTEIN; childNodes[2]->obIn = OB_TEMP_BLUESTEIN; @@ -1985,6 +1985,13 @@ void TreeNode::TraverseTreeAssignBuffersLogicA(OperatingBuffer& flipIn, // Assert that the kernel chain is connected for(int i = 1; i < childNodes.size(); ++i) { + if(childNodes[i - 1]->scheme == CS_KERNEL_CHIRP) + { + // The Bluestein algorithm uses a separate buffer which is + // convolved with the input; the chain assumption isn't true here. + // NB: we assume that the CS_KERNEL_CHIRP is first in the chain. + continue; + } assert(childNodes[i - 1]->obOut == childNodes[i]->obIn); } } @@ -2289,7 +2296,9 @@ void TreeNode::TraverseTreeAssignParamsLogicA() } else { - assert((row1Plan->obOut == OB_USER_OUT) || (row1Plan->obOut == OB_TEMP_CMPLX_FOR_REAL) + // TODO: add documentation for assert. 
+ assert((row1Plan->obOut == OB_USER_IN) || (row1Plan->obOut == OB_USER_OUT) + || (row1Plan->obOut == OB_TEMP_CMPLX_FOR_REAL) || (row1Plan->obOut == OB_TEMP_BLUESTEIN)); row1Plan->outStride.push_back(outStride[0]); @@ -2454,7 +2463,8 @@ void TreeNode::TraverseTreeAssignParamsLogicA() // here we don't have B info right away, we get it through its parent // TODO: what is this assert for? - assert(parent->obOut == OB_USER_OUT || parent->obOut == OB_TEMP_CMPLX_FOR_REAL + assert(parent->obOut == OB_USER_IN || parent->obOut == OB_USER_OUT + || parent->obOut == OB_TEMP_CMPLX_FOR_REAL || parent->scheme == CS_REAL_TRANSFORM_EVEN); // T-> B