Skip to content

Commit

Permalink
Fix bug in Bluestein Algorithm and add tests (#220)
Browse files Browse the repository at this point in the history
* Add prime-size tests for 1D c2c transforms

* Fix Bluestein algorithm bug
  • Loading branch information
malcolmroberts authored Jul 23, 2019
1 parent 47a05c3 commit 5fe696b
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 8 deletions.
19 changes: 15 additions & 4 deletions clients/tests/accuracy_test_1D.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ class accuracy_test_complex_pow2_double : public ::testing::Test
#define MIX_RANGE \
6, 10, 12, 15, 20, 30, 120, 150, 225, 240, 300, 486, 600, 900, 1250, 1500, 1875, 2160, 2187, \
2250, 2500, 3000, 4000, 12000, 24000, 72000
// Prime transform lengths: every prime from 7 through 97.
// NOTE(review): presumably chosen to exercise the Bluestein code path for
// lengths outside the pow2/pow3/pow5/mix lists -- confirm against the planner.
#define PRIME_RANGE \
7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97

#define LARGE_RANGE \
4096, 4050, 4000, 3888, 3840, 3750, 3645, 3600, 3456, 3375, 3240, 3200, 3125, 3072, 3000, \
Expand All @@ -77,10 +79,11 @@ class accuracy_test_complex_pow2_double : public ::testing::Test
160, 150, 144, 135, 128, 125, 120, 108, 100, 96, 90, 81, 80, 75, 72, 64, 60, 54, 50, 48, \
45, 40, 36, 32, 30, 27, 25,\ 24, 20, 18, 16, 15, 12, 10, 9, 8, 6, 5, 4, 3, 2, 1

static std::vector<size_t> pow2_range = {POW2_RANGE};
static std::vector<size_t> pow3_range = {POW3_RANGE};
static std::vector<size_t> pow5_range = {POW5_RANGE};
static std::vector<size_t> mix_range = {MIX_RANGE};
// Transform-length lists, one vector per size family, each initialized from
// the correspondingly named *_RANGE macro. prime_range feeds the
// rocfft_prime_1D test instantiation.
static std::vector<size_t> pow2_range = {POW2_RANGE};
static std::vector<size_t> pow3_range = {POW3_RANGE};
static std::vector<size_t> pow5_range = {POW5_RANGE};
static std::vector<size_t> mix_range = {MIX_RANGE};
static std::vector<size_t> prime_range = {PRIME_RANGE};

static size_t batch_range[] = {1};

Expand Down Expand Up @@ -415,6 +418,14 @@ INSTANTIATE_TEST_CASE_P(rocfft_pow_random_1D,
ValuesIn(transform_range),
ValuesIn(stride_range)));

// Instantiates the accuracy_test_complex parameterized suite under the name
// rocfft_prime_1D. Combine() generates one test parameter tuple for every
// combination of prime_range, batch_range, placeness_range, transform_range,
// and stride_range, in that order.
INSTANTIATE_TEST_CASE_P(rocfft_prime_1D,
accuracy_test_complex,
Combine(ValuesIn(prime_range),
ValuesIn(batch_range),
ValuesIn(placeness_range),
ValuesIn(transform_range),
ValuesIn(stride_range)));

// *****************************************************
// REAL HERMITIAN
// *****************************************************
Expand Down
18 changes: 14 additions & 4 deletions library/src/plan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1672,11 +1672,11 @@ void TreeNode::TraverseTreeAssignBuffersLogicA(OperatingBuffer& flipIn,
assert(childNodes.size() == 7);

assert(childNodes[0]->scheme == CS_KERNEL_CHIRP);
childNodes[0]->obIn = obIn;
childNodes[0]->obIn = OB_TEMP_BLUESTEIN;
childNodes[0]->obOut = OB_TEMP_BLUESTEIN;

assert(childNodes[1]->scheme == CS_KERNEL_PAD_MUL);
childNodes[1]->obIn = OB_TEMP_BLUESTEIN;
childNodes[1]->obIn = obIn;
childNodes[1]->obOut = OB_TEMP_BLUESTEIN;

childNodes[2]->obIn = OB_TEMP_BLUESTEIN;
Expand Down Expand Up @@ -1985,6 +1985,13 @@ void TreeNode::TraverseTreeAssignBuffersLogicA(OperatingBuffer& flipIn,
// Assert that the kernel chain is connected
for(int i = 1; i < childNodes.size(); ++i)
{
if(childNodes[i - 1]->scheme == CS_KERNEL_CHIRP)
{
// The Bluestein algorithm uses a separate buffer which is
// convolved with the input; the chain assumption isn't true here.
// NB: we assume that the CS_KERNEL_CHIRP is first in the chain.
continue;
}
assert(childNodes[i - 1]->obOut == childNodes[i]->obIn);
}
}
Expand Down Expand Up @@ -2289,7 +2296,9 @@ void TreeNode::TraverseTreeAssignParamsLogicA()
}
else
{
assert((row1Plan->obOut == OB_USER_OUT) || (row1Plan->obOut == OB_TEMP_CMPLX_FOR_REAL)
// TODO: add documentation for assert.
assert((row1Plan->obOut == OB_USER_IN) || (row1Plan->obOut == OB_USER_OUT)
|| (row1Plan->obOut == OB_TEMP_CMPLX_FOR_REAL)
|| (row1Plan->obOut == OB_TEMP_BLUESTEIN));

row1Plan->outStride.push_back(outStride[0]);
Expand Down Expand Up @@ -2454,7 +2463,8 @@ void TreeNode::TraverseTreeAssignParamsLogicA()
// here we don't have B info right away, we get it through its parent

// TODO: what is this assert for?
assert(parent->obOut == OB_USER_OUT || parent->obOut == OB_TEMP_CMPLX_FOR_REAL
assert(parent->obOut == OB_USER_IN || parent->obOut == OB_USER_OUT
|| parent->obOut == OB_TEMP_CMPLX_FOR_REAL
|| parent->scheme == CS_REAL_TRANSFORM_EVEN);

// T-> B
Expand Down

0 comments on commit 5fe696b

Please sign in to comment.