Skip to content

Commit

Permalink
Merge branch 'main' into ci_test_2024-09-26
Browse files (browse the repository at this point in the history)
  • Loading branch information
naoyam authored Jan 29, 2025
2 parents 1b96ae2 + 075f97f commit eca6e17
Showing 1 changed file with 4 additions and 9 deletions.
13 changes: 4 additions & 9 deletions csrc/runtime/compiled_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,6 @@ std::vector<char> compileNvrtcProgramToPtx(const nvrtcProgram& program) {
std::unique_ptr<executor_utils::CudaExecutable> compileSource(
const std::string& full_src_code,
const std::string& func_name,
- const std::string& kernel_name,
const bool compile_to_sass,
NvrtcCompileDriver& nvrtc_compile) {
std::stringstream log;
Expand All @@ -700,7 +699,7 @@ std::unique_ptr<executor_utils::CudaExecutable> compileSource(
NVFUSER_NVRTC_SAFE_CALL(nvrtcDestroyProgram(&program));
});

- createNvrtcProgram(program, kernel_name, full_src_code);
+ createNvrtcProgram(program, func_name, full_src_code);

NVFUSER_NVRTC_SAFE_CALL(nvrtcAddNameExpression(program, func_name.c_str()));
log << nvrtc_compile.invoke(program, full_src_code) << std::endl;
Expand All @@ -716,15 +715,15 @@ std::unique_ptr<executor_utils::CudaExecutable> compileSource(
compiled_kernel->cubin = compileNvrtcProgramToCubin(program);
if (isDebugDumpEnabled(DebugDumpOption::Cubin)) {
compiled_kernel->cubin_filename =
- dumpCompiledCodeToFile(compiled_kernel->cubin, kernel_name, ".cubin");
+ dumpCompiledCodeToFile(compiled_kernel->cubin, func_name, ".cubin");
}
}

if (!compile_to_sass || isDebugDumpEnabled(DebugDumpOption::Ptx)) {
compiled_kernel->ptx = compileNvrtcProgramToPtx(program);
if (isDebugDumpEnabled(DebugDumpOption::Ptx)) {
compiled_kernel->ptx_filename =
- dumpCompiledCodeToFile(compiled_kernel->ptx, kernel_name, ".ptx");
+ dumpCompiledCodeToFile(compiled_kernel->ptx, func_name, ".ptx");
}
}

Expand Down Expand Up @@ -810,11 +809,7 @@ std::unique_ptr<executor_utils::CudaExecutable> getCudaExecutable(
(compile_to_sass ? compiled_kernel->cubin
: compiled_kernel->ptx)))) {
compiled_kernel = compileSource(
- full_src_code,
- func_name,
- compiled_kernel->kernel_name,
- compile_to_sass,
- nvrtc_compile_driver);
+ full_src_code, func_name, compile_to_sass, nvrtc_compile_driver);
log << compiled_kernel->compile_log << std::endl;
if (use_kernel_db) {
auto result = kernel_db.write(
Expand Down

0 comments on commit eca6e17

Please sign in to comment.