Describe the bug
The TMA descriptor can be copied, but the copied descriptor fails with an illegal memory access when it is used to copy any data:
terminate called after throwing an instance of 'thrust::THRUST_200400_900_NS::system::system_error'
what(): CUDA free failed: cudaErrorIllegalAddress: an illegal memory access was encountered
Aborted (core dumped)
Steps/Code to reproduce bug
I am working with TMA copy in my code. When I tried to store a TMA object inside another object and copy it in the constructor, I hit an illegal memory access at runtime during the copy.
Here is an example of the copy:
// ... Kernel code
// auto &tmaLoad = params.tmaLoad; // If I take a reference, not a copy, it works just fine
auto tmaLoad = params.tmaLoad;     // Copy TMA load object. Compiles, no warnings
// Create shared memory for copy and barrier
Tensor sS = make_tensor(make_smem_ptr(shared_storage.smem.data()), SmemLayout{});
auto &mbarrier = shared_storage.mbarrier;
// ...
const int warp_idx = cutlass::canonical_warp_idx_sync();
const bool lane_predicate = cute::elect_one_sync();
constexpr int kTmaTransactionBytes = sizeof(ArrayEngine<Element, size(SmemLayout{})>);
// ...
auto cta_tmaS = tmaLoad.get_slice(Int<0>{});
Tensor mS = tmaLoad.get_tma_tensor(shape(gmemLayout));
auto blkCoord = make_coord(blockIdx.x, blockIdx.y);
Tensor gS = local_tile(mS, tileShape, blkCoord);
// Perform copy
if (warp_idx == 0 and lane_predicate) {
mbarrier.init(1 /* arrive count */);
mbarrier.arrive_and_expect_tx(kTmaTransactionBytes);
copy(tmaLoad.with(reinterpret_cast<BarrierType &>(mbarrier)),
     cta_tmaS.partition_S(gS), cta_tmaS.partition_D(sS)); // <-- Runtime error happens here
}
__syncthreads();
mbarrier.wait(0 /* phase */);
cutlass::arch::fence_view_async_shared();
// ...
Full code that fails:
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <chrono>

#include <thrust/device_vector.h>
#include <thrust/host_vector.h>

#include <cute/layout.hpp>
#include <cutlass/numeric_types.h>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/tensor.hpp>
#include <cutlass/arch/barrier.h>
#include <cutlass/cluster_launch.hpp>
#include <cutlass/cutlass.h>

#include <cutlass/util/GPU_Clock.hpp>
#include <cutlass/util/command_line.h>
#include <cutlass/util/helper_cuda.hpp>
#include <cutlass/util/print_error.hpp>
#include <cutlass/detail/layout.hpp>

using namespace cute;

template <class Element, class SmemLayout> struct SharedStorageTMA {
cute::array_aligned<Element, cute::cosize_v<SmemLayout>,
cutlass::detail::alignment_for_swizzle(SmemLayout{})>
smem;
cutlass::arch::ClusterTransactionBarrier mbarrier;
};
template <typename _TiledCopyS, typename _TiledCopyD, typename _GmemLayout,
          typename _SmemLayout, typename _TileShape>
struct Params {
using TiledCopyS = _TiledCopyS;
using TiledCopyD = _TiledCopyD;
using GmemLayout = _GmemLayout;
using SmemLayout = _SmemLayout;
using TileShape = _TileShape;
TiledCopyS const tmaLoad;
TiledCopyD const tmaStore;
GmemLayout const gmemLayout;
SmemLayout const smemLayout;
TileShape const tileShape;
Params(_TiledCopyS const &tmaLoad, _TiledCopyD const &tmaStore,
_GmemLayout const &gmemLayout, _SmemLayout const &smemLayout,
_TileShape const &tileShape)
: tmaLoad(tmaLoad), tmaStore(tmaStore), gmemLayout(gmemLayout),
smemLayout(smemLayout), tileShape(tileShape) {}
};
template <int kNumThreads, class Element, class Params>
__global__ static void __launch_bounds__(kNumThreads, 1)
copyTMAKernel(CUTE_GRID_CONSTANT Params const params) {
using namespace cute;

//
// Get layouts and tiled copies from Params struct
//
using GmemLayout = typename Params::GmemLayout;
using SmemLayout = typename Params::SmemLayout;
using TileShape = typename Params::TileShape;
auto &gmemLayout = params.gmemLayout;
auto &smemLayout = params.smemLayout;
auto &tileShape = params.tileShape;
auto &tmaStore = params.tmaStore;
// auto &tmaLoad = params.tmaLoad;
auto tmaLoad = params.tmaLoad; // Copy TMA load object. Compiles, no warnings

// Use SharedStorage structure to allocate aligned SMEM addresses.
extern __shared__ char shared_memory[];
using SharedStorage = SharedStorageTMA<Element, SmemLayout>;
SharedStorage &shared_storage =
*reinterpret_cast<SharedStorage *>(shared_memory);
// Define smem tensor
Tensor sS =
make_tensor(make_smem_ptr(shared_storage.smem.data()), smemLayout);
// Get mbarrier object and its value type
auto &mbarrier = shared_storage.mbarrier;
using BarrierType = cutlass::arch::ClusterTransactionBarrier::ValueType;
static_assert(cute::is_same_v<BarrierType, uint64_t>,
              "Value type of mbarrier must be uint64_t.");
// Constants used for TMA
const int warp_idx = cutlass::canonical_warp_idx_sync();
const bool lane_predicate = cute::elect_one_sync();
constexpr int kTmaTransactionBytes =
    sizeof(ArrayEngine<Element, size(SmemLayout{})>);

// Prefetch TMA descriptors for load and store
if (warp_idx == 0 && lane_predicate) {
prefetch_tma_descriptor(tmaLoad.get_tma_descriptor());
prefetch_tma_descriptor(tmaStore.get_tma_descriptor());
}
// Get CTA view of gmem tensor
Tensor mS = tmaLoad.get_tma_tensor(shape(gmemLayout));
auto blkCoord = make_coord(blockIdx.x, blockIdx.y);
Tensor gS = local_tile(mS, tileShape, blkCoord);
auto cta_tmaS = tmaLoad.get_slice(Int<0>{});
if (warp_idx == 0 and lane_predicate) {
mbarrier.init(1 /* arrive count */);
mbarrier.arrive_and_expect_tx(kTmaTransactionBytes);
copy(tmaLoad.with(reinterpret_cast<BarrierType &>(mbarrier)),
cta_tmaS.partition_S(gS), cta_tmaS.partition_D(sS));
}
__syncthreads();
mbarrier.wait(0 /* phase */);
cutlass::arch::fence_view_async_shared();
// Get CTA view of gmem out tensor
auto mD = tmaStore.get_tma_tensor(shape(gmemLayout));
auto gD = local_tile(mD, tileShape, blkCoord);
auto cta_tmaD = tmaStore.get_slice(Int<0>{});
if (warp_idx == 0 and lane_predicate) {
cute::copy(tmaStore, cta_tmaD.partition_S(sS), cta_tmaD.partition_D(gD));
// cute::tma_store_arrive();
}
// cute::tma_store_wait<0>();
}
template <int TILE_M = 32, int TILE_N = 4>
int copy_tma_example(int M, int N) {
using bM = Int<TILE_M>;
using bN = Int<TILE_N>;
using Element = float;
constexpr int THREADS = TILE_M * TILE_N;
auto tensor_shape = make_shape(M, N);
// Allocate and initialize
thrust::host_vector<Element> h_S(size(tensor_shape)); // (M, N)
thrust::host_vector<Element> h_D(size(tensor_shape)); // (M, N)
for (size_t i = 0; i < h_S.size(); ++i) {
h_S[i] = static_cast<Element>(float(i));
}
thrust::device_vector<Element> d_S = h_S;
thrust::device_vector<Element> d_D = h_D;
// Make tensors
auto gmemLayout = make_layout(tensor_shape, LayoutRight{});
Tensor tensor_S = make_tensor(
make_gmem_ptr(thrust::raw_pointer_cast(d_S.data())), gmemLayout);
Tensor tensor_D = make_tensor(
make_gmem_ptr(thrust::raw_pointer_cast(d_D.data())), gmemLayout);
auto tileShape = make_shape(bM{}, bN{});
auto smemLayout = make_layout(tileShape, LayoutRight{});
auto tma_load =
make_tma_copy(SM90_TMA_LOAD{}, tensor_S, smemLayout);
auto tma_store = make_tma_copy(SM90_TMA_STORE{}, tensor_D, smemLayout);
Params params(tma_load, tma_store, gmemLayout, smemLayout, tileShape);
dim3 gridDim(ceil_div(M, TILE_M), ceil_div(N, TILE_N));
dim3 blockDim(THREADS);
int smem_size = int(sizeof(SharedStorageTMA<Element, decltype(smemLayout)>));
void const *kernel =
    (void const *)copyTMAKernel<THREADS, Element, decltype(params)>;
dim3 cluster_dims(1);
// Define the cluster launch parameter structure.
cutlass::ClusterLaunchParams launch_params{gridDim, blockDim, cluster_dims,
smem_size};
cutlass::Status status = cutlass::launch_kernel_on_cluster(launch_params, kernel, params);
cudaError result = cudaDeviceSynchronize();
return 0;
}
int main(int argc, char const **argv) {
cutlass::CommandLine cmd(argc, argv);
int M = 16384;
int N = 16384;
return copy_tma_example(M, N);
}
Expected behavior
I'm not familiar with the details of how the TMA descriptor works under the hood, but this is either a bug in the copy constructor, or the copy constructor should be deleted if the object is unsafe to copy.
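For illustration, here is a minimal sketch of what I mean by deleting the copy constructor; the HypotheticalTiledCopy type below is made up for this example and is not the actual CuTe class. With a deleted copy constructor, the accidental copy above would fail at compile time instead of crashing at runtime:
struct HypotheticalTiledCopy {
  HypotheticalTiledCopy() = default;
  // Forbid copies so a device-side copy of the descriptor cannot compile
  HypotheticalTiledCopy(HypotheticalTiledCopy const &) = delete;
  HypotheticalTiledCopy &operator=(HypotheticalTiledCopy const &) = delete;
};
// auto tmaLoad = params.tmaLoad;  // would be a compile-time error
// auto &tmaLoad = params.tmaLoad; // taking a reference still works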
Environment details (please complete the following information):
I used the latest main commit at the moment of opening this issue (affd1b6)
I used this cmake file:
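(The original file is not reproduced here; the sketch below is only an assumption of what a minimal CMakeLists.txt for this reproducer could look like. The CUTLASS_DIR variable and the source file name copy_tma.cu are made up, not taken from the original file.)
# Hypothetical minimal CMakeLists.txt; CUTLASS_DIR and copy_tma.cu are assumptions
cmake_minimum_required(VERSION 3.25)
project(copy-tma LANGUAGES CXX CUDA)
set(CUTLASS_DIR "" CACHE PATH "Path to a CUTLASS checkout")
add_executable(copy-tma copy_tma.cu)
target_include_directories(copy-tma PRIVATE
  ${CUTLASS_DIR}/include
  ${CUTLASS_DIR}/tools/util/include)
target_compile_features(copy-tma PRIVATE cxx_std_17)
set_target_properties(copy-tma PROPERTIES CUDA_ARCHITECTURES 90a) # TMA requires sm90a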
to compile and run with CUDA 12.5 + GCC 10.3.0 like this:
mkdir build && cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j8
./copy-tma
terminate called after throwing an instance of 'thrust::THRUST_200400_900_NS::system::system_error'
what(): CUDA free failed: cudaErrorIllegalAddress: an illegal memory access was encountered
Aborted (core dumped)
(NVIDIA-SMI 555.42.06 Driver Version: 555.42.06 CUDA Version: 12.5)