[BUG] Copy constructor of tma descriptor produces a corrupted copy #2081

pavlo-hilei opened this issue Feb 5, 2025 · 1 comment
pavlo-hilei commented Feb 5, 2025

Describe the bug
The TMA descriptor can be copied, but the copied descriptor fails with an illegal memory access when it is used to copy any data:

terminate called after throwing an instance of 'thrust::THRUST_200400_900_NS::system::system_error'
  what():  CUDA free failed: cudaErrorIllegalAddress: an illegal memory access was encountered
Aborted (core dumped)

Steps/Code to reproduce bug
I am working with TMA copies in my code. When I stored a TMA copy object inside another object and copied it in a constructor, I hit an illegal memory access at runtime during the copy.

Here is an example of the copy:

// ... Kernel code
// auto& tmaLoad = params.tmaLoad;      // If I take a reference, not copy, it works just fine
auto tmaLoad = params.tmaLoad;      // Copy tma load object. Compiles, no warnings

// Create shared memory for copy and barrier
Tensor sS = make_tensor(make_smem_ptr(shared_storage.smem.data()), SmemLayout{});
auto &mbarrier = shared_storage.mbarrier;

// ...
const int warp_idx = cutlass::canonical_warp_idx_sync();
const bool lane_predicate = cute::elect_one_sync();
constexpr int kTmaTransactionBytes = sizeof(ArrayEngine<Element, size(SmemLayout{})>);

// .. 
auto cta_tmaS = tmaLoad.get_slice(Int<0>{});
Tensor mS = tmaLoad.get_tma_tensor(shape(gmemLayout));
auto blkCoord = make_coord(blockIdx.x, blockIdx.y);
Tensor gS = local_tile(mS, tileShape, blkCoord);


// Perform copy
if (warp_idx == 0 and lane_predicate) {
  mbarrier.init(1 /* arrive count */);
  mbarrier.arrive_and_expect_tx(kTmaTransactionBytes);
  copy(tmaLoad.with(reinterpret_cast<BarrierType &>(mbarrier)),
       cta_tmaS.partition_S(gS), cta_tmaS.partition_D(sS));  // <-- Here happens runtime error
}
__syncthreads();

mbarrier.wait(0 /* phase */);

cutlass::arch::fence_view_async_shared();

// ... 

Full code that fails:

#include <cassert>
#include <cstdio>
#include <cstdlib>

#include <chrono>

#include <thrust/device_vector.h>
#include <thrust/host_vector.h>

#include <cute/layout.hpp>
#include <cutlass/numeric_types.h>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/tensor.hpp>
#include <cutlass/arch/barrier.h>
#include <cutlass/cluster_launch.hpp>
#include <cutlass/cutlass.h>

#include <cutlass/util/GPU_Clock.hpp>
#include <cutlass/util/command_line.h>
#include <cutlass/util/helper_cuda.hpp>
#include <cutlass/util/print_error.hpp>

#include <cutlass/detail/layout.hpp>
#include <cutlass/util/command_line.h>

using namespace cute;

template <class Element, class SmemLayout> struct SharedStorageTMA {
  cute::array_aligned<Element, cute::cosize_v<SmemLayout>,
                      cutlass::detail::alignment_for_swizzle(SmemLayout{})>
      smem;
  cutlass::arch::ClusterTransactionBarrier mbarrier;
};


template <typename _TiledCopyS, typename _TiledCopyD, typename _GmemLayout,
        typename _SmemLayout, typename _TileShape>
struct Params {
  using TiledCopyS = _TiledCopyS;
  using TiledCopyD = _TiledCopyD;
  using GmemLayout = _GmemLayout;
  using SmemLayout = _SmemLayout;
  using TileShape = _TileShape;

  TiledCopyS const tmaLoad;
  TiledCopyD const tmaStore;
  GmemLayout const gmemLayout;
  SmemLayout const smemLayout;
  TileShape const tileShape;

  Params(_TiledCopyS const &tmaLoad, _TiledCopyD const &tmaStore,
         _GmemLayout const &gmemLayout, _SmemLayout const &smemLayout,
         _TileShape const &tileShape)
      : tmaLoad(tmaLoad), tmaStore(tmaStore), gmemLayout(gmemLayout),
        smemLayout(smemLayout), tileShape(tileShape) {}
};

template <int kNumThreads, class Element, class Params>
__global__ static void __launch_bounds__(kNumThreads, 1)
  copyTMAKernel(CUTE_GRID_CONSTANT Params const params) {
  using namespace cute;

  //
  // Get layouts and tiled copies from Params struct
  //
  using GmemLayout = typename Params::GmemLayout;
  using SmemLayout = typename Params::SmemLayout;
  using TileShape = typename Params::TileShape;

  auto &gmemLayout = params.gmemLayout;
  auto &smemLayout = params.smemLayout;
  auto &tileShape = params.tileShape;

  auto &tmaStore = params.tmaStore;
  // auto &tmaLoad = params.tmaLoad;
  auto tmaLoad = params.tmaLoad;      // Copy tma load object. Compiles, no warnings

  // Use Shared Storage structure to allocate aligned SMEM addresses.
  extern __shared__ char shared_memory[];
  using SharedStorage = SharedStorageTMA<Element, SmemLayout>;
  SharedStorage &shared_storage =
      *reinterpret_cast<SharedStorage *>(shared_memory);

  // Define smem tensor
  Tensor sS =
      make_tensor(make_smem_ptr(shared_storage.smem.data()), smemLayout);

  // Get mbarrier object and its value type
  auto &mbarrier = shared_storage.mbarrier;
  using BarrierType = cutlass::arch::ClusterTransactionBarrier::ValueType;
  static_assert(cute::is_same_v<BarrierType, uint64_t>,
                "Value type of mbarrier is uint64_t.");

  // Constants used for TMA
  const int warp_idx = cutlass::canonical_warp_idx_sync();
  const bool lane_predicate = cute::elect_one_sync();
  constexpr int kTmaTransactionBytes =
      sizeof(ArrayEngine<Element, size(SmemLayout{})>);

  // Prefetch TMA descriptors for load and store
  if (warp_idx == 0 && lane_predicate) {
    prefetch_tma_descriptor(tmaLoad.get_tma_descriptor());
    prefetch_tma_descriptor(tmaStore.get_tma_descriptor());
  }

  // Get CTA view of gmem tensor
  Tensor mS = tmaLoad.get_tma_tensor(shape(gmemLayout));
  auto blkCoord = make_coord(blockIdx.x, blockIdx.y);
  Tensor gS = local_tile(mS, tileShape, blkCoord);

  auto cta_tmaS = tmaLoad.get_slice(Int<0>{});

  if (warp_idx == 0 and lane_predicate) {
    mbarrier.init(1 /* arrive count */);
    mbarrier.arrive_and_expect_tx(kTmaTransactionBytes);
    copy(tmaLoad.with(reinterpret_cast<BarrierType &>(mbarrier)),
         cta_tmaS.partition_S(gS), cta_tmaS.partition_D(sS));
  }
  __syncthreads();

  mbarrier.wait(0 /* phase */);

  cutlass::arch::fence_view_async_shared();

  // Get CTA view of gmem out tensor
  auto mD = tmaStore.get_tma_tensor(shape(gmemLayout));
  auto gD = local_tile(mD, tileShape, blkCoord);

  auto cta_tmaD = tmaStore.get_slice(Int<0>{});

  if (warp_idx == 0 and lane_predicate) {
    cute::copy(tmaStore, cta_tmaD.partition_S(sS), cta_tmaD.partition_D(gD));
    // cute::tma_store_arrive();
  }
  // cute::tma_store_wait<0>();
}


template <int TILE_M = 32, int TILE_N = 4>
int copy_tma_example(int M, int N) {
  using bM = Int<TILE_M>;
  using bN = Int<TILE_N>;
  using Element = float;
  constexpr int THREADS = TILE_M * TILE_N;

  auto tensor_shape = make_shape(M, N);

  // Allocate and initialize
  thrust::host_vector<Element> h_S(size(tensor_shape)); // (M, N)
  thrust::host_vector<Element> h_D(size(tensor_shape)); // (M, N)

  for (size_t i = 0; i < h_S.size(); ++i) {
    h_S[i] = static_cast<Element>(float(i));
  }

  thrust::device_vector<Element> d_S = h_S;
  thrust::device_vector<Element> d_D = h_D;

  // Make tensors
  auto gmemLayout = make_layout(tensor_shape, LayoutRight{});
  Tensor tensor_S = make_tensor(
      make_gmem_ptr(thrust::raw_pointer_cast(d_S.data())), gmemLayout);
  Tensor tensor_D = make_tensor(
      make_gmem_ptr(thrust::raw_pointer_cast(d_D.data())), gmemLayout);

  auto tileShape = make_shape(bM{}, bN{});
  auto smemLayout = make_layout(tileShape, LayoutRight{});
  auto tma_load =
      make_tma_copy(SM90_TMA_LOAD{}, tensor_S, smemLayout);

  auto tma_store = make_tma_copy(SM90_TMA_STORE{}, tensor_D, smemLayout);

  Params params(tma_load, tma_store, gmemLayout, smemLayout, tileShape);

  dim3 gridDim(ceil_div(M, TILE_M), ceil_div(N, TILE_N));
  dim3 blockDim(THREADS);

  int smem_size = int(sizeof(SharedStorageTMA<Element, decltype(smemLayout)>));
  void const *kernel =
      (void const *)copyTMAKernel<THREADS, Element, decltype(params)>;

  dim3 cluster_dims(1);

  // Define the cluster launch parameter structure.
  cutlass::ClusterLaunchParams launch_params{gridDim, blockDim, cluster_dims,
                                             smem_size};

  cutlass::Status status = cutlass::launch_kernel_on_cluster(launch_params, kernel, params);
  cudaError result = cudaDeviceSynchronize();
  return 0;
}


int main(int argc, char const **argv) {
  cutlass::CommandLine cmd(argc, argv);
  int M = 16384;
  int N = 16384;
  return copy_tma_example(M, N);
}

Expected behavior
I'm not familiar with the details of how the TMA descriptor works under the hood, but this is either a bug in the copy constructor, or the copy constructor should be deleted if it is unsafe to copy this object.
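
For illustration only, here is a rough sketch of what I mean by deleting the copy constructor (hypothetical class name, not the actual CUTLASS type), so that an accidental device-side copy like auto tmaLoad = params.tmaLoad; would fail at compile time instead of at runtime:

// Hypothetical sketch only -- not the real CUTLASS class.
struct TmaCopyHandle {
  // ... descriptor, layouts, and other members ...
  TmaCopyHandle(TmaCopyHandle const &) = delete;             // forbid copies
  TmaCopyHandle &operator=(TmaCopyHandle const &) = delete;  // forbid copy-assignment
};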

Environment details (please complete the following information):
I used the latest main commit at the moment of opening this issue (affd1b6).
I used this CMake file:

cmake_minimum_required(VERSION 3.12 FATAL_ERROR)

project(tma-example LANGUAGES CUDA)

add_executable(copy-tma copy_tma.cu)
set_target_properties(copy-tma PROPERTIES CUDA_ARCHITECTURES "90")
target_include_directories(copy-tma PRIVATE
    cutlass/include
    cutlass/util/include
)
target_compile_options(copy-tma PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-std=c++17 -O3 -Xcompiler=-Wno-psabi -Xcompiler=-fno-strict-aliasing --expt-relaxed-constexpr -lineinfo --ptxas-options=-v>)

to compile and run with CUDA 12.5 + GCC 10.3.0 like this:

mkdir build && cd build
cmake -DCMAKE_BUILD_TYPE=Release .. 
make -j8
./copy-tma
terminate called after throwing an instance of 'thrust::THRUST_200400_900_NS::system::system_error'
  what():  CUDA free failed: cudaErrorIllegalAddress: an illegal memory access was encountered
Aborted (core dumped)
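
If it helps with triage: one optional way to localize the illegal access is to run the reproducer under compute-sanitizer from the CUDA toolkit (suggestion only, output not shown here):

# assumes compute-sanitizer from the CUDA toolkit is on PATH
compute-sanitizer --tool memcheck ./copy-tma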

(NVIDIA-SMI 555.42.06 Driver Version: 555.42.06 CUDA Version: 12.5)

hwu36 (Collaborator) commented Feb 6, 2025

@thakkarV @ANIKET-SHIVAM
