
Naive register <-> tmem load/store support #3786

Merged: zasdfgbnm merged 7 commits from tmem-no-alloc into main on Jan 30, 2025
Conversation

@zasdfgbnm (Collaborator) commented Jan 29, 2025

Extracted from #3755 to make code review easier.

This PR adds a new unit test, TMemTest.GmemRegTMemRegGmemCopy, that schedules a copy kernel gmem -> register -> tmem -> register -> gmem, and updates our system with the minimum changes required to make this test pass.

The purpose of this PR is not to provide a good implementation of TMem support, but just to provide the absolute minimum required for us to get started. Limitations are:

  1. The index is hard-coded to zero, so this PR does not touch the interesting topic of "how to schedule a TMem tensor?"
  2. TMem is used without allocation. Using memory that has not been allocated is clearly a wrong way to program, but, as described in the code comment, if a fusion has only one TMem TensorView, it is guaranteed to work.

Generated code:

__global__ void nvfuser_none_f0_c0_r0_g0(Tensor<float, 1, 1> T0, Tensor<float, 1, 1> T4) {
  nvfuser_index_t i0;
  i0 = ((nvfuser_index_t)threadIdx.x) + (32 * ((nvfuser_index_t)blockIdx.x));
  bool b1;
  b1 = i0 < T0.logical_size[0LL];
  Array<float, 1, 1> T1;
  T1[0] = 0;
  if (b1) {
    T1[0]
       = T0[((T0.alloc_stride[0LL] * ((nvfuser_index_t)threadIdx.x)) + ((32 * T0.alloc_stride[0LL]) * ((nvfuser_index_t)blockIdx.x)))];
  }
  asm volatile(
    "tcgen05.st.sync.aligned.32x32b.x1.b32 [%0], {%1};\n"
    :
    :"r"(0U),
     "f"((*reinterpret_cast<Array<float, 1, 1>*>(&T1[0]))[0])
  );
  asm volatile("tcgen05.wait::st.sync.aligned;\n");
  Array<float, 1, 1> T3;
  asm(
    "tcgen05.ld.sync.aligned.32x32b.x1.b32 {%0}, [%1];\n"
    :"=f"((*reinterpret_cast<Array<float, 1, 1>*>(&T3[0]))[0])
    :"r"(0U)
  );
  asm volatile("tcgen05.wait::ld.sync.aligned;\n");
  if (b1) {
    T4[i0]
       = T3[0];
  }
}
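For orientation, the fusion behind a test like this might be defined roughly as follows. This is a sketch, not the actual TMemTest.GmemRegTMemRegGmemCopy code from the PR: the calls used (makeContigTensor, set, setMemoryType, split, parallelize) are existing nvFuser C++ API, but the exact scheduling and how the LdTMem/StTMem op types get attached to the copies are assumptions.

// Sketch: gmem -> register -> tmem -> register -> gmem copy fusion.
// Not the actual test code from this PR.
Fusion fusion;
FusionGuard fg(&fusion);

TensorView* tv0 = makeContigTensor(1); // input in global memory
fusion.addInput(tv0);
TensorView* tv1 = set(tv0);            // gmem -> register
TensorView* tv2 = set(tv1);            // register -> tmem
TensorView* tv3 = set(tv2);            // tmem -> register
TensorView* tv4 = set(tv3);            // register -> gmem
fusion.addOutput(tv4);

tv2->setMemoryType(MemoryType::Tensor); // place the middle copy in TMem
// (This sketch assumes lowering picks LoadStoreOpType::StTMem / LdTMem for
// the copies into and out of tv2 based on the Tensor memory type.)

// One element per thread; 32 threads per block, matching the
// "threadIdx.x + 32 * blockIdx.x" indexing in the generated kernel.
for (TensorView* tv : {tv1, tv2, tv3, tv4}) {
  tv->split(0, 32);
  tv->axis(0)->parallelize(ParallelType::BIDx);
  tv->axis(1)->parallelize(ParallelType::TIDx);
}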


github-actions bot commented Jan 29, 2025

PR Reviewer Guide 🔍

(Review updated until commit 1b1f4cc)

Here are some key observations to aid the review process:

⏱️ Estimated effort to review: 4 🔵🔵🔵🔵⚪
🧪 PR contains tests
⚡ Recommended focus areas for review

Missing Allocation

The current implementation does not allocate tensor memory, which may lead to issues with multiple CTAs accessing the same memory.

TensorMemoryInfo computeTMemInfo(Fusion* fusion) {
  bool found = false;
  for (auto tv : fusion->allTvs()) {
    if (tv->getMemoryType() == MemoryType::Tensor) {
      NVF_ERROR(!found, "Only one tensor on TMem is supported");
      found = true;
    }
  }
  return {};
}
Hardcoded Index

The index is hardcoded to zero, which may not be the intended behavior for all use cases.

// TODO: hard coded index zero for now.
auto index = IrBuilder::create<Val>(0, DataType::UInt32);
in = IrBuilder::create<kir::TensorIndex>(
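For background on what a non-zero index would eventually need to encode (not addressed in this PR): per the PTX documentation, a TMem address is a 32-bit value with the lane (row) in the upper 16 bits and the column in the lower 16 bits, so real indexing would build something like the hypothetical helper below rather than the constant 0.

// Hypothetical helper, not part of this PR: compose a 32-bit TMem address
// from a lane (row) index and a column index, assuming the
// (lane << 16) | column layout described in the PTX ISA.
__device__ inline uint32_t composeTMemAddress(uint32_t lane, uint32_t column) {
  return (lane << 16) | (column & 0xFFFFu);
}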
Limited Support

The current implementation only supports 32-bit types in tensor memory, which may limit its usability.

  // TODO: support other types of ld/st
  auto ptx = "tcgen05.ld.sync.aligned.32x32b.x1.b32";
  registerReplace(
      ldst,
      IrBuilder::create<kir::Asm>(
          ptx,
          std::vector<Val*>{ldst->out()},
          std::vector<Val*>{ldst->in()}));
  auto wait_ptx = "tcgen05.wait::ld.sync.aligned";
  registerInsertAfter(
      ldst,
      IrBuilder::create<kir::Asm>(
          wait_ptx,
          std::vector<Val*>{},
          std::vector<Val*>{},
          kir::Asm::Options{/*volatile=*/true}));
} else if (ldst->opType() == LoadStoreOpType::StTMem) {
  // TODO: support other types of ld/st
  auto ptx = "tcgen05.st.sync.aligned.32x32b.x1.b32";
  registerReplace(
      ldst,
      IrBuilder::create<kir::Asm>(
          ptx,
          std::vector<Val*>{},
          std::vector<Val*>{ldst->out(), ldst->in()},
          kir::Asm::Options{/*volatile=*/true}));
  auto wait_ptx = "tcgen05.wait::st.sync.aligned";
  registerInsertAfter(
      ldst,
      IrBuilder::create<kir::Asm>(
          wait_ptx,
          std::vector<Val*>{},
          std::vector<Val*>{},
          kir::Asm::Options{/*volatile=*/true}));
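One way the 32-bit, x1-only limitation might eventually be lifted (a sketch, not code from this PR; the helper name is invented): derive the .xN suffix of the PTX string from the number of 32-bit registers moved per access, which the PTX ISA allows to be any power of two from 1 to 128 for the 32x32b shape.

// Hypothetical helper: build the tcgen05.ld PTX string for N registers
// per access instead of hard-coding ".x1". Uses std::string/std::to_string.
std::string tmemLoadPtx(int64_t num_regs) {
  NVF_ERROR(
      num_regs >= 1 && num_regs <= 128 &&
          (num_regs & (num_regs - 1)) == 0,
      "num_regs must be a power of two between 1 and 128");
  return "tcgen05.ld.sync.aligned.32x32b.x" + std::to_string(num_regs) +
      ".b32";
}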

zasdfgbnm marked this pull request as ready for review on January 29, 2025 at 06:20.
@zasdfgbnm (Collaborator, Author)

!test

@naoyam (Collaborator) left a comment

Overall LGTM. Just left one small question.

zasdfgbnm merged commit 149c163 into main on Jan 30, 2025; 51 checks passed.
zasdfgbnm deleted the tmem-no-alloc branch on January 30, 2025 at 08:23.
@rdspring1 (Collaborator) left a comment

Do you plan to wrap the PTX assembly with CUDA functions?

// different CTA to different physical address. There is no virtual TMem
// address. All addresses are physical addresses.
//
// Because multiple CTAs can execute on the same SM simultaneously, there must
Collaborator (commenting on the code excerpt above)

Due to this handshaking mechanism, is it better to have only a single CTA occupy an SM?

@zasdfgbnm (Collaborator, Author)

Are you talking about kernel design for better perf? My guess is that if you allocate at the beginning of the kernel and relinquish right after allocating, the latency should be acceptable if you want to run multiple CTAs on an SM. But we need to test it before drawing any conclusion.

Collaborator

Yes, for maximum performance.
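For context on the scheme discussed above, "allocate at the beginning and relinquish right after" would look roughly like the prologue/epilogue below in generated CUDA. This is a sketch based on the PTX ISA's tcgen05.alloc, tcgen05.relinquish_alloc_permit, and tcgen05.dealloc instructions, not code from this PR; the column count and the single-warp assumption are purely illustrative.

// Sketch of a TMem allocation prologue/epilogue (single warp per CTA assumed).
__shared__ uint32_t tmem_base;

// tcgen05.alloc writes the allocated TMem base address into shared memory.
// The column count must be a power of two (32 here, illustrative only).
uint32_t ncols = 32;
asm volatile(
    "tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;\n"
    :
    : "r"((uint32_t)__cvta_generic_to_shared(&tmem_base)), "r"(ncols));

// Signal that this CTA will not allocate any more TMem.
asm volatile("tcgen05.relinquish_alloc_permit.cta_group::1.sync.aligned;\n");

__syncthreads();
uint32_t tmem_addr = tmem_base;

// ... tcgen05.st / tcgen05.ld through tmem_addr, as in the generated kernel ...

// Free the columns before the CTA exits.
asm volatile(
    "tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1;\n"
    :
    : "r"(tmem_addr), "r"(ncols));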

@zasdfgbnm (Collaborator, Author)

> Do you plan to wrap the PTX assembly with CUDA functions?

I like kir::Asm more than CUDA functions. Our generated .cu file is already 10k+ lines of code, and I don't want to add more unless necessary.
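For comparison, the wrapper approach being asked about would look roughly like this. The sketch is derived from the inline asm in the generated kernel above; the function names are invented, and this is not what the PR implements.

// Hypothetical device-side wrappers for the tcgen05 load/store used above.
__device__ inline void tmemStore32x32bX1(uint32_t tmem_addr, float value) {
  asm volatile(
      "tcgen05.st.sync.aligned.32x32b.x1.b32 [%0], {%1};\n"
      :
      : "r"(tmem_addr), "f"(value));
  asm volatile("tcgen05.wait::st.sync.aligned;\n");
}

__device__ inline float tmemLoad32x32bX1(uint32_t tmem_addr) {
  float value;
  asm(
      "tcgen05.ld.sync.aligned.32x32b.x1.b32 {%0}, [%1];\n"
      : "=f"(value)
      : "r"(tmem_addr));
  asm volatile("tcgen05.wait::ld.sync.aligned;\n");
  return value;
}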
