Tensor memory support 1 #3755
Conversation
PR Reviewer Guide 🔍 (Review updated until commit f9adf69). Here are some key observations to aid the review process:
```cpp
  testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__);
}

TEST_F(TMemTest, AddKernel) {
```
```cuda
__global__ void nvfuser_none_f0_c0_r0_g0(Tensor<float, 1, 1> T0, Tensor<float, 1, 1> T4, Tensor<float, 1, 1> T9) {
  alignas(16) extern __shared__ char array[];
  const unsigned smem_offset = 0;
  nvfuser_index_t i0;
  i0 = ((nvfuser_index_t)threadIdx.x) + (32LL * ((nvfuser_index_t)blockIdx.x));
  bool b1;
  b1 = i0 < T0.logical_size[0LL];
  uint32_t* T10 = reinterpret_cast<uint32_t*>(array + smem_offset + 16LL);
  asm volatile("tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;\n"::"r"((uint32_t)(toSmem(T10))), "n"(32U));
  uint32_t* T11 = reinterpret_cast<uint32_t*>(array + smem_offset + 0LL);
  asm volatile("tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;\n"::"r"((uint32_t)(toSmem(T11))), "n"(32U));
  asm volatile("tcgen05.relinquish_alloc_permit.cta_group::1.sync.aligned;\n");
  __syncthreads();
  float T1[1LL];
  T1[0LL] = 0LL;
  if (b1) {
    T1[0LL]
       = T0[((T0.alloc_stride[0LL] * ((nvfuser_index_t)threadIdx.x)) + ((32LL * T0.alloc_stride[0LL]) * ((nvfuser_index_t)blockIdx.x)))];
  }
  TMemTensor T2(T10[0LL], 0, 0);
  asm volatile(
    "tcgen05.st.sync.aligned.32x32b.x1.b32 [%0], {%1};\n"
    :
    :"r"((uint32_t)(T2 + Array<uint16_t, 2, 1>{0, 0})),
     "f"((*reinterpret_cast<Array<float, 1, 1>*>(&T1[0LL]))[0])
  );
  asm volatile("tcgen05.wait::st.sync.aligned;\n");
  float T3[1LL];
  asm(
    "tcgen05.ld.sync.aligned.32x32b.x1.b32 {%0}, [%1];\n"
    :"=f"((*reinterpret_cast<Array<float, 1, 1>*>(&T3[0LL]))[0])
    :"r"((uint32_t)(T2 + Array<uint16_t, 2, 1>{0, 0}))
  );
  asm volatile("tcgen05.wait::ld.sync.aligned;\n");
  asm volatile("tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1;\n"::"r"(T10[0LL]), "n"(32U));
  float T5[1LL];
  T5[0LL] = 0LL;
  if (b1) {
    T5[0LL]
       = T4[((T4.alloc_stride[0LL] * ((nvfuser_index_t)threadIdx.x)) + ((32LL * T4.alloc_stride[0LL]) * ((nvfuser_index_t)blockIdx.x)))];
  }
  TMemTensor T6(T11[0LL], 0, 0);
  asm volatile(
    "tcgen05.st.sync.aligned.32x32b.x1.b32 [%0], {%1};\n"
    :
    :"r"((uint32_t)(T6 + Array<uint16_t, 2, 1>{0, 0})),
     "f"((*reinterpret_cast<Array<float, 1, 1>*>(&T5[0LL]))[0])
  );
  asm volatile("tcgen05.wait::st.sync.aligned;\n");
  float T7[1LL];
  asm(
    "tcgen05.ld.sync.aligned.32x32b.x1.b32 {%0}, [%1];\n"
    :"=f"((*reinterpret_cast<Array<float, 1, 1>*>(&T7[0LL]))[0])
    :"r"((uint32_t)(T6 + Array<uint16_t, 2, 1>{0, 0}))
  );
  asm volatile("tcgen05.wait::ld.sync.aligned;\n");
  asm volatile("tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1;\n"::"r"(T11[0LL]), "n"(32U));
  float T8[1LL];
  T8[0LL]
     = T3[0LL]
     + T7[0LL];
  if (b1) {
    T9[i0]
       = T8[0LL];
  }
}
```
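For readers new to these instructions, below is a minimal standalone sketch (not part of this PR) that distills the allocate / relinquish / store / wait / load / wait / deallocate sequence the generated kernel above follows. It assumes a Blackwell (sm_100a) target and a launch with exactly one 32-thread warp per block, and it uses the CUDA intrinsic `__cvta_generic_to_shared` where the generated code uses nvFuser's `toSmem` helper.

```cuda
#include <cstdint>

// Minimal sketch of the tcgen05 usage pattern seen in the generated kernel
// above; assumes sm_100a and exactly one 32-thread warp per block.
__global__ void tmem_pattern_sketch() {
  // tcgen05.alloc writes the allocated TMem base address into shared memory.
  __shared__ uint32_t tmem_addr;

  // Allocate 32 columns of tensor memory for this CTA, then give up the right
  // to allocate more.
  asm volatile(
      "tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;\n" ::
      "r"((uint32_t)__cvta_generic_to_shared(&tmem_addr)), "n"(32U));
  asm volatile("tcgen05.relinquish_alloc_permit.cta_group::1.sync.aligned;\n");
  __syncthreads();  // make the address written by tcgen05.alloc visible

  // Store one fp32 value per lane into tensor memory, then wait for the store.
  float v = (float)threadIdx.x;
  asm volatile(
      "tcgen05.st.sync.aligned.32x32b.x1.b32 [%0], {%1};\n" ::
      "r"(tmem_addr), "f"(v));
  asm volatile("tcgen05.wait::st.sync.aligned;\n");

  // Load the value back, then wait before the destination register is reused.
  float r;
  asm("tcgen05.ld.sync.aligned.32x32b.x1.b32 {%0}, [%1];\n"
      : "=f"(r)
      : "r"(tmem_addr));
  asm volatile("tcgen05.wait::ld.sync.aligned;\n");

  // Free the 32 columns that were allocated.
  asm volatile(
      "tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1;\n" ::
      "r"(tmem_addr), "n"(32U));
}
```

In the generated kernel above, the same sequence appears twice, once per TMem TensorView, each with its own shared-memory slot (`T10`, `T11`) receiving the allocated base address.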
tests/cpp/test_memory.cpp (outdated)
```cpp
// Tensor memory tests
using TMemTest = NVFuserTest;

TEST_F(TMemTest, GmemRegTMemRegGmemCopy) {
```
```cuda
__global__ void nvfuser_none_f0_c0_r0_g0(Tensor<float, 1, 1> T0, Tensor<float, 1, 1> T4) {
  alignas(16) extern __shared__ char array[];
  const unsigned smem_offset = 0;
  nvfuser_index_t i0;
  i0 = ((nvfuser_index_t)threadIdx.x) + (32LL * ((nvfuser_index_t)blockIdx.x));
  bool b1;
  b1 = i0 < T0.logical_size[0LL];
  uint32_t* T5 = reinterpret_cast<uint32_t*>(array + smem_offset + 0LL);
  asm volatile("tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;\n"::"r"((uint32_t)(toSmem(T5))), "n"(32U));
  asm volatile("tcgen05.relinquish_alloc_permit.cta_group::1.sync.aligned;\n");
  __syncthreads();
  float T1[1LL];
  T1[0LL] = 0LL;
  if (b1) {
    T1[0LL]
       = T0[((T0.alloc_stride[0LL] * ((nvfuser_index_t)threadIdx.x)) + ((32LL * T0.alloc_stride[0LL]) * ((nvfuser_index_t)blockIdx.x)))];
  }
  TMemTensor T2(T5[0LL], 0, 0);
  asm volatile(
    "tcgen05.st.sync.aligned.32x32b.x1.b32 [%0], {%1};\n"
    :
    :"r"((uint32_t)(T2 + Array<uint16_t, 2, 1>{0, 0})),
     "f"((*reinterpret_cast<Array<float, 1, 1>*>(&T1[0LL]))[0])
  );
  asm volatile("tcgen05.wait::st.sync.aligned;\n");
  float T3[1LL];
  asm(
    "tcgen05.ld.sync.aligned.32x32b.x1.b32 {%0}, [%1];\n"
    :"=f"((*reinterpret_cast<Array<float, 1, 1>*>(&T3[0LL]))[0])
    :"r"((uint32_t)(T2 + Array<uint16_t, 2, 1>{0, 0}))
  );
  asm volatile("tcgen05.wait::ld.sync.aligned;\n");
  asm volatile("tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1;\n"::"r"(T5[0LL]), "n"(32U));
  if (b1) {
    T4[i0]
       = T3[0LL];
  }
}
```
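The `TMemTensor` helper used in these kernels (e.g. `TMemTensor T2(T5[0LL], 0, 0)` combined with `T2 + Array<uint16_t, 2, 1>{0, 0}`) comes from nvFuser's runtime headers, which are not shown here. Purely for illustration, a hypothetical stand-in with the same usage shape could look like the sketch below, assuming the PTX convention that a tensor-memory address packs the lane index into bits 31:16 and the column index into bits 15:0; the actual runtime implementation may differ.

```cuda
#include <cstdint>

// Hypothetical stand-in for nvFuser's TMemTensor runtime helper (illustration
// only, not the actual implementation). A tensor-memory address is assumed to
// be a packed 32-bit value: lane index in the upper 16 bits, column index in
// the lower 16 bits.
struct TMemTensorSketch {
  uint32_t base;  // packed (lane << 16) | column base address

  __device__ TMemTensorSketch(uint32_t addr, uint16_t lane, uint16_t column)
      : base(addr + ((uint32_t)lane << 16) + column) {}

  // Mirrors the `T2 + Array<uint16_t, 2, 1>{lane, column}` usage above:
  // returns the packed address offset by (lane, column).
  __device__ uint32_t plus(uint16_t lane, uint16_t column) const {
    return base + ((uint32_t)lane << 16) + column;
  }
};
```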
Minor change extracted from: #3755
Extracted from #3755 to make code review easier. This PR adds a new unit test `TMemTest.GmemRegTMemRegGmemCopy` that schedules a copy kernel gmem -> register -> tmem -> register -> gmem, and updates our system with the minimum changes required to make this test pass. The purpose of this PR is not to provide a good implementation of TMem support, but just to provide the absolute minimal starting point. Limitations are:

1. The index is hard-coded to zero, so this PR does not touch the interesting topic of "how to schedule a TMem tensor?".
2. The TMem is used without allocation. Using memory that has not been allocated is clearly the wrong way to program, but as described in the code comment, if a fusion has only one TMem TensorView, it is guaranteed to work.

Generated code:

```CUDA
__global__ void nvfuser_none_f0_c0_r0_g0(Tensor<float, 1, 1> T0, Tensor<float, 1, 1> T4) {
  nvfuser_index_t i0;
  i0 = ((nvfuser_index_t)threadIdx.x) + (32 * ((nvfuser_index_t)blockIdx.x));
  bool b1;
  b1 = i0 < T0.logical_size[0LL];
  Array<float, 1, 1> T1;
  T1[0] = 0;
  if (b1) {
    T1[0] = T0[((T0.alloc_stride[0LL] * ((nvfuser_index_t)threadIdx.x)) + ((32 * T0.alloc_stride[0LL]) * ((nvfuser_index_t)blockIdx.x)))];
  }
  asm volatile(
    "tcgen05.st.sync.aligned.32x32b.x1.b32 [%0], {%1};\n"
    :
    :"r"(0U),
     "f"((*reinterpret_cast<Array<float, 1, 1>*>(&T1[0]))[0])
  );
  asm volatile("tcgen05.wait::st.sync.aligned;\n");
  Array<float, 1, 1> T3;
  asm(
    "tcgen05.ld.sync.aligned.32x32b.x1.b32 {%0}, [%1];\n"
    :"=f"((*reinterpret_cast<Array<float, 1, 1>*>(&T3[0]))[0])
    :"r"(0U)
  );
  asm volatile("tcgen05.wait::ld.sync.aligned;\n");
  if (b1) {
    T4[i0] = T3[0];
  }
}
```
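For context, a rough sketch of how such a copy test might be expressed with nvFuser's C++ test API is shown below. The `MemoryType::Tensor` enum value is an assumption based on this PR's description, and the scheduling (one 32-thread warp per block) is inferred from the generated kernel; the actual test in `tests/cpp/test_memory.cpp` may differ, especially in how the TMem TensorView is scheduled.

```cuda
// Rough sketch only: MemoryType::Tensor and the TMem scheduling details are
// assumptions; standard nvFuser test idioms (makeContigTensor, set, split,
// parallelize) are used as-is.
TEST_F(TMemTest, GmemRegTMemRegGmemCopySketch) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());

  TensorView* tv0 = makeContigTensor(1);   // input in global memory
  fusion->addInput(tv0);
  TensorView* tv1 = set(tv0);              // gmem -> register
  TensorView* tv2 = set(tv1);              // register -> tensor memory
  TensorView* tv3 = set(tv2);              // tensor memory -> register
  TensorView* tv4 = set(tv3);              // register -> gmem
  fusion->addOutput(tv4);

  tv2->setMemoryType(MemoryType::Tensor);  // assumed name of the new memory type

  // One 32-thread warp per block, matching the generated kernel above.
  for (TensorView* tv : {tv1, tv2, tv3, tv4}) {
    tv->split(0, 32);
    tv->axis(0)->parallelize(ParallelType::BIDx);
    tv->axis(1)->parallelize(ParallelType::TIDx);
  }

  // Compile, run, and validate with testValidate(...) as in the diff above.
}
```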
Not ready for review