From 2cf9724507840e7ec1c1c797940b7752506219e2 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 22 Oct 2025 21:42:14 +0200 Subject: [PATCH 01/77] register everything Signed-off-by: Ivan Butygin --- water/tools/water-opt/CMakeLists.txt | 4 ++++ water/tools/water-opt/water-opt.cpp | 15 ++++----------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/water/tools/water-opt/CMakeLists.txt b/water/tools/water-opt/CMakeLists.txt index 1af3103c0..c8c726b89 100644 --- a/water/tools/water-opt/CMakeLists.txt +++ b/water/tools/water-opt/CMakeLists.txt @@ -12,6 +12,10 @@ set(LIBS MLIRWaterTestTransforms MLIRWaterTestDialect + + MLIRRegisterAllDialects + MLIRRegisterAllExtensions + MLIRRegisterAllPasses ) add_llvm_executable(water-opt diff --git a/water/tools/water-opt/water-opt.cpp b/water/tools/water-opt/water-opt.cpp index edb7ae681..27e381f02 100644 --- a/water/tools/water-opt/water-opt.cpp +++ b/water/tools/water-opt/water-opt.cpp @@ -20,6 +20,7 @@ #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/InitAllDialects.h" +#include "mlir/InitAllExtensions.h" #include "mlir/InitAllPasses.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" #include "mlir/Transforms/Passes.h" @@ -40,19 +41,11 @@ int main(int argc, char **argv) { mlir::water::registerPasses(); mlir::water::test::registerAllPasses(); wave::registerPasses(); - mlir::arith::registerArithIntRangeOptsPass(); - mlir::registerCanonicalizerPass(); - mlir::registerCSEPass(); - mlir::registerLoopInvariantCodeMotionPass(); - mlir::registerLowerAffinePass(); + mlir::registerAllPasses(); mlir::DialectRegistry registry; - registry.insert(); + mlir::registerAllDialects(registry); + mlir::registerAllExtensions(registry); mlir::water::test::registerWaterTestDialect(registry); From a63e1398412150dd9c45bd69ac911f671094f654 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 22 Oct 2025 21:42:36 +0200 Subject: [PATCH 02/77] initial pipeline Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/water.py | 70 ++++++++++++++++++++++++---------- 1 file changed, 49 insertions(+), 21 deletions(-) diff --git a/wave_lang/kernel/wave/water.py b/wave_lang/kernel/wave/water.py index 3851f1a19..ea0433dab 100644 --- a/wave_lang/kernel/wave/water.py +++ b/wave_lang/kernel/wave/water.py @@ -191,8 +191,40 @@ def is_water_available() -> bool: return False +def get_water_binary_path() -> str: + return find_binary("water-opt") + + +def make_linear_pass_pipeline( + pipeline: Sequence[ + tuple[str, dict[str, Any]] | tuple[str, dict[str, Any], str] | str + ], +) -> str: + def make_pass_arguments( + name: str, args: dict[str, Any], module_name: str = None + ) -> str: + ret = ( + name + + "{" + + " ".join("=".join((key, str(value))) for (key, value) in args.items()) + + "}" + ) + if module_name: + ret = module_name + "(" + ret + ")" + return ret + + return ( + "--pass-pipeline=builtin.module(" + + ",".join( + entry if isinstance(entry, str) else make_pass_arguments(*entry) + for entry in pipeline + ) + + ")" + ) + + def water_leak_in_bounds_check(module: Module, override_ir: str = ""): - binary = find_binary("water-opt") + binary = get_water_binary_path() generic_mlir = _deiree(module) if override_ir == "" else override_ir pipeline = [ ( @@ -208,26 +240,6 @@ def water_leak_in_bounds_check(module: Module, override_ir: str = ""): "water-check-static-assertions", ] - def make_linear_pass_pipeline( - pipeline: Sequence[tuple[str, dict[str, Any]] | str], - ) -> str: - def make_pass_arguments(name: str, args: dict[str, Any]) -> str: - return ( - name - + "{" - + " ".join("=".join((key, str(value))) for (key, value) in args.items()) - + "}" - ) - - return ( - "--pass-pipeline=builtin.module(" - + ",".join( - entry if isinstance(entry, str) else make_pass_arguments(*entry) - for entry in pipeline - ) - + ")" - ) - def get_code_context( filename: str, start_line: int, end_line, context: int = 2 ) -> str: @@ -364,3 +376,19 @@ def diagnostic_from_json( ) else: print("[info] No out-of-bounds accesses detected.") + + +def water_lowering_pipeline(module: Module) -> Module: + binary = get_water_binary_path() + mlir_asm = module.operation.get_asm() + pipeline = [ + ("convert-gpu-to-rocdl", {"use-bare-ptr-memref-call-conv": "1"}, "gpu.module"), + ("rocdl-attach-target", {"chip": "gfx1100"}, "gpu.module"), + ] + result = subprocess.check_output( + [binary, make_linear_pass_pipeline(pipeline)], + input=mlir_asm, + text=True, + stderr=subprocess.STDOUT, + ) + print(result) From fc8ed5e23debf5d09642bcdb3dcb5238adc943dd Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 22 Oct 2025 22:51:33 +0200 Subject: [PATCH 03/77] register llvm translation Signed-off-by: Ivan Butygin --- water/tools/water-opt/CMakeLists.txt | 1 + water/tools/water-opt/water-opt.cpp | 3 +++ 2 files changed, 4 insertions(+) diff --git a/water/tools/water-opt/CMakeLists.txt b/water/tools/water-opt/CMakeLists.txt index c8c726b89..d5c9b0a21 100644 --- a/water/tools/water-opt/CMakeLists.txt +++ b/water/tools/water-opt/CMakeLists.txt @@ -16,6 +16,7 @@ set(LIBS MLIRRegisterAllDialects MLIRRegisterAllExtensions MLIRRegisterAllPasses + MLIRToLLVMIRTranslationRegistration ) add_llvm_executable(water-opt diff --git a/water/tools/water-opt/water-opt.cpp b/water/tools/water-opt/water-opt.cpp index 27e381f02..125349387 100644 --- a/water/tools/water-opt/water-opt.cpp +++ b/water/tools/water-opt/water-opt.cpp @@ -22,6 +22,7 @@ #include "mlir/InitAllDialects.h" #include "mlir/InitAllExtensions.h" #include "mlir/InitAllPasses.h" +#include "mlir/Target/LLVMIR/Dialect/All.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" #include "mlir/Transforms/Passes.h" @@ -47,6 +48,8 @@ int main(int argc, char **argv) { mlir::registerAllDialects(registry); mlir::registerAllExtensions(registry); + mlir::registerAllGPUToLLVMIRTranslations(registry); + mlir::water::test::registerWaterTestDialect(registry); return mlir::asMainReturnCode( From c7dd5161749f1b2514bf6c37b4e011e0d0cb5a16 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 22 Oct 2025 22:51:59 +0200 Subject: [PATCH 04/77] more pipeline Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/water.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/wave_lang/kernel/wave/water.py b/wave_lang/kernel/wave/water.py index ea0433dab..33e9180dc 100644 --- a/wave_lang/kernel/wave/water.py +++ b/wave_lang/kernel/wave/water.py @@ -378,17 +378,30 @@ def diagnostic_from_json( print("[info] No out-of-bounds accesses detected.") -def water_lowering_pipeline(module: Module) -> Module: +def water_lowering_pipeline(module: Module, target_chip: str) -> Module: binary = get_water_binary_path() mlir_asm = module.operation.get_asm() pipeline = [ + "lower-affine", + "canonicalize", + "cse", + "loop-invariant-code-motion", + "int-range-optimizations", ("convert-gpu-to-rocdl", {"use-bare-ptr-memref-call-conv": "1"}, "gpu.module"), - ("rocdl-attach-target", {"chip": "gfx1100"}, "gpu.module"), + ("rocdl-attach-target", {"chip": target_chip}), + ("gpu-to-llvm", {"use-bare-pointers-for-kernels": "1"}), + "reconcile-unrealized-casts", + "canonicalize", + "cse", + "gpu-module-to-binary", ] - result = subprocess.check_output( - [binary, make_linear_pass_pipeline(pipeline)], - input=mlir_asm, - text=True, - stderr=subprocess.STDOUT, - ) + try: + result = subprocess.check_output( + [binary, make_linear_pass_pipeline(pipeline)], + input=mlir_asm, + text=True, + ) + except subprocess.CalledProcessError as e: + print(e.stderr) + raise e print(result) From 3f594a86a4e98765f421b9a0a1a1bf0c3fa0d463 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 23 Oct 2025 19:16:24 +0200 Subject: [PATCH 05/77] return Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/water.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/wave_lang/kernel/wave/water.py b/wave_lang/kernel/wave/water.py index 33e9180dc..4e53096a5 100644 --- a/wave_lang/kernel/wave/water.py +++ b/wave_lang/kernel/wave/water.py @@ -404,4 +404,6 @@ def water_lowering_pipeline(module: Module, target_chip: str) -> Module: except subprocess.CalledProcessError as e: print(e.stderr) raise e - print(result) + + with module.context: + return Module.parse(result) From 9c6369cb3ce83199d9d0fe47765f61ef16b8764b Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 23 Oct 2025 23:34:44 +0200 Subject: [PATCH 06/77] gpu-to-gpu-runtime Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/water.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/wave_lang/kernel/wave/water.py b/wave_lang/kernel/wave/water.py index 4e53096a5..f7ccef482 100644 --- a/wave_lang/kernel/wave/water.py +++ b/wave_lang/kernel/wave/water.py @@ -394,6 +394,10 @@ def water_lowering_pipeline(module: Module, target_chip: str) -> Module: "canonicalize", "cse", "gpu-module-to-binary", + "water-gpu-to-gpu-runtime", + "symbol-dce", + "canonicalize", + "cse", ] try: result = subprocess.check_output( From f7433f53c53c179741ed0032d88defdb7bf22409 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 9 Nov 2025 22:10:04 +0100 Subject: [PATCH 07/77] WIP: Add LLVM ExecutionEngine stub MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Create initial stub for LLVM ExecutionEngine in wave_lang/kernel/wave/execution_engine/ with nanobind bindings and build configuration. The implementation is not yet complete. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Signed-off-by: Ivan Butygin --- setup.py | 6 + .../wave/execution_engine/CMakeLists.txt | 99 ++++++++++++ .../kernel/wave/execution_engine/bindings.cpp | 80 ++++++++++ .../execution_engine/execution_engine.cpp | 142 ++++++++++++++++++ .../wave/execution_engine/execution_engine.h | 96 ++++++++++++ 5 files changed, 423 insertions(+) create mode 100644 wave_lang/kernel/wave/execution_engine/CMakeLists.txt create mode 100644 wave_lang/kernel/wave/execution_engine/bindings.cpp create mode 100644 wave_lang/kernel/wave/execution_engine/execution_engine.cpp create mode 100644 wave_lang/kernel/wave/execution_engine/execution_engine.h diff --git a/setup.py b/setup.py index b178c1783..597dbf465 100644 --- a/setup.py +++ b/setup.py @@ -282,6 +282,12 @@ def initialize_options(self): if BUILD_WATER: ext_modules += [ + CMakeExtension( + "wave_execution_engine", + "wave_lang/kernel/wave/execution_engine", + install_dir="wave_lang/kernel/wave/execution_engine", + need_llvm=True, + ), CMakeExtension( "water", "water", diff --git a/wave_lang/kernel/wave/execution_engine/CMakeLists.txt b/wave_lang/kernel/wave/execution_engine/CMakeLists.txt new file mode 100644 index 000000000..56a77dcb3 --- /dev/null +++ b/wave_lang/kernel/wave/execution_engine/CMakeLists.txt @@ -0,0 +1,99 @@ +# Copyright 2025 The IREE Authors +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +cmake_minimum_required(VERSION 3.19...3.27) +project(wave_execution_engine) + +# Skip building on macOS +if(APPLE) + message(STATUS "Skipping wave_execution_engine build on ${CMAKE_SYSTEM_NAME}") + return() +endif() + +# Set the C++ standard +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +find_package(Python 3.10 COMPONENTS Interpreter Development.Module REQUIRED) + +if (NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES) + set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build." FORCE) + set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo") +endif() + +# Detect the installed nanobind package and import it into CMake +execute_process( + COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir + OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE nanobind_ROOT + COMMAND_ERROR_IS_FATAL ANY) +find_package(nanobind CONFIG REQUIRED) + +# Build the core parts of nanobind once +nanobind_build_library(nanobind SHARED) + +# TODO: Add LLVM and MLIR dependencies when implementing +# find_package(LLVM REQUIRED CONFIG) +# find_package(MLIR REQUIRED CONFIG) +# +# message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}") +# message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") +# message(STATUS "Found MLIR in: ${MLIR_DIR}") +# +# include_directories(${LLVM_INCLUDE_DIRS}) +# include_directories(${MLIR_INCLUDE_DIRS}) +# +# add_definitions(${LLVM_DEFINITIONS}) + +# Compile an extension library +add_library(wave_execution_engine MODULE + execution_engine.cpp + bindings.cpp +) + +# Include current directory for header files +target_include_directories(wave_execution_engine PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) + +# Link against nanobind +target_link_libraries(wave_execution_engine PRIVATE nanobind) + +# TODO: Link against LLVM and MLIR libraries when implementing +# target_link_libraries(wave_execution_engine PRIVATE +# MLIRIR +# MLIRParser +# MLIRExecutionEngine +# MLIRTargetLLVMIRExport +# MLIRLLVMDialect +# LLVMCore +# LLVMSupport +# LLVMExecutionEngine +# LLVMJIT +# LLVMOrcJIT +# LLVMRuntimeDyld +# ) + +set_target_properties(wave_execution_engine PROPERTIES LINK_WHAT_YOU_USE TRUE) + +# Enable size optimizations +nanobind_opt_size(wave_execution_engine) + +# Enable link time optimization +nanobind_lto(wave_execution_engine) + +# Set the default symbol visibility to 'hidden' +nanobind_set_visibility(wave_execution_engine) + +# Strip unneeded symbols and debug info from the binary (only active in release builds) +nanobind_strip(wave_execution_engine) + +# Disable the stack protector +nanobind_disable_stack_protector(wave_execution_engine) + +# Set the Python extension suffix +nanobind_extension(wave_execution_engine) + +# Set important linker flags +nanobind_link_options(wave_execution_engine) + +install(TARGETS wave_execution_engine DESTINATION ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) diff --git a/wave_lang/kernel/wave/execution_engine/bindings.cpp b/wave_lang/kernel/wave/execution_engine/bindings.cpp new file mode 100644 index 000000000..53b487a86 --- /dev/null +++ b/wave_lang/kernel/wave/execution_engine/bindings.cpp @@ -0,0 +1,80 @@ +// Copyright 2025 The IREE Authors +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "execution_engine.h" +#include +#include +#include + +namespace nb = nanobind; + +// Nanobind module definition for Python bindings +NB_MODULE(wave_execution_engine, m) { + m.doc() = "LLVM ExecutionEngine bindings for Wave JIT compilation"; + + // Bind the WaveExecutionEngine class + nb::class_(m, "ExecutionEngine") + .def(nb::init<>(), + "Create a new WaveExecutionEngine instance") + .def("initialize", &wave::WaveExecutionEngine::initialize, + nb::arg("mlir_module_str"), + "Initialize the execution engine with an MLIR module string.\n\n" + "Args:\n" + " mlir_module_str: MLIR module as a string\n\n" + "Raises:\n" + " RuntimeError: If initialization fails or already initialized") + .def("load_llvm_ir", &wave::WaveExecutionEngine::load_llvm_ir, + nb::arg("ir_str"), + "Load a pre-compiled LLVM IR module.\n\n" + "Args:\n" + " ir_str: LLVM IR as a string\n\n" + "Raises:\n" + " RuntimeError: If loading fails") + .def("invoke", &wave::WaveExecutionEngine::invoke, + nb::arg("func_name"), nb::arg("args"), + "Invoke a function by name with the given arguments.\n\n" + "Args:\n" + " func_name: Name of the function to invoke\n" + " args: List of arguments as uint64_t values\n\n" + "Raises:\n" + " RuntimeError: If engine not initialized or function not found") + .def("get_function_address", &wave::WaveExecutionEngine::get_function_address, + nb::arg("func_name"), + "Get the address of a function by name.\n\n" + "Args:\n" + " func_name: Name of the function\n\n" + "Returns:\n" + " Address of the function as an integer\n\n" + "Raises:\n" + " RuntimeError: If engine not initialized or function not found") + .def("is_initialized", &wave::WaveExecutionEngine::is_initialized, + "Check if the execution engine is initialized.\n\n" + "Returns:\n" + " True if initialized, False otherwise") + .def("optimize", &wave::WaveExecutionEngine::optimize, + nb::arg("opt_level") = 2, + "Optimize the module with the given optimization level.\n\n" + "Args:\n" + " opt_level: Optimization level (0-3, default 2)\n" + " 0 = No optimization\n" + " 1 = Basic optimizations\n" + " 2 = Standard optimizations (default)\n" + " 3 = Aggressive optimizations\n\n" + "Raises:\n" + " RuntimeError: If engine not initialized or invalid opt_level") + .def("dump_llvm_ir", &wave::WaveExecutionEngine::dump_llvm_ir, + "Dump the current LLVM IR module as a string.\n\n" + "Returns:\n" + " LLVM IR as a string\n\n" + "Raises:\n" + " RuntimeError: If engine not initialized"); + + // Bind the global initialization function + m.def("initialize_llvm_mlir", &wave::initialize_llvm_mlir, + "Initialize LLVM and MLIR infrastructure.\n\n" + "Must be called before creating any WaveExecutionEngine instances.\n\n" + "Raises:\n" + " RuntimeError: If initialization fails"); +} diff --git a/wave_lang/kernel/wave/execution_engine/execution_engine.cpp b/wave_lang/kernel/wave/execution_engine/execution_engine.cpp new file mode 100644 index 000000000..c3727b68e --- /dev/null +++ b/wave_lang/kernel/wave/execution_engine/execution_engine.cpp @@ -0,0 +1,142 @@ +// Copyright 2025 The IREE Authors +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "execution_engine.h" +#include + +namespace wave { + +WaveExecutionEngine::WaveExecutionEngine() : initialized_(false) {} + +WaveExecutionEngine::~WaveExecutionEngine() { + if (initialized_) { + cleanup(); + } +} + +void WaveExecutionEngine::initialize(const std::string& mlir_module_str) { + if (initialized_) { + throw std::runtime_error("ExecutionEngine already initialized"); + } + + // TODO: Implement MLIR module parsing and LLVM ExecutionEngine creation + // Steps: + // 1. Parse MLIR module from string + // 2. Convert MLIR to LLVM IR + // 3. Create LLVM ExecutionEngine + // 4. JIT compile the module + + throw std::runtime_error("ExecutionEngine initialization not yet implemented"); +} + +void WaveExecutionEngine::load_llvm_ir(const std::string& ir_str) { + // TODO: Implement LLVM IR loading + // Steps: + // 1. Parse LLVM IR string + // 2. Create LLVM module from IR + // 3. Set up ExecutionEngine with the module + // 4. Finalize the module for execution + + throw std::runtime_error("LLVM IR loading not yet implemented"); +} + +void WaveExecutionEngine::invoke(const std::string& func_name, + const std::vector& args) { + if (!initialized_) { + throw std::runtime_error("ExecutionEngine not initialized"); + } + + // TODO: Implement function lookup and invocation + // Steps: + // 1. Look up function by name in ExecutionEngine + // 2. Get function pointer + // 3. Marshal arguments based on function signature + // 4. Invoke function with marshalled arguments + // 5. Return results (if any) + + throw std::runtime_error("Function invocation not yet implemented"); +} + +uintptr_t WaveExecutionEngine::get_function_address(const std::string& func_name) { + if (!initialized_) { + throw std::runtime_error("ExecutionEngine not initialized"); + } + + // TODO: Implement function address lookup + // Steps: + // 1. Look up function by name in ExecutionEngine + // 2. Get function address using ExecutionEngine::getFunctionAddress() + // 3. Return the address as uintptr_t + + throw std::runtime_error("Function address lookup not yet implemented"); +} + +bool WaveExecutionEngine::is_initialized() const { + return initialized_; +} + +void WaveExecutionEngine::optimize(int opt_level) { + if (!initialized_) { + throw std::runtime_error("ExecutionEngine not initialized"); + } + + if (opt_level < 0 || opt_level > 3) { + throw std::runtime_error("Invalid optimization level. Must be 0-3"); + } + + // TODO: Implement module optimization + // Steps: + // 1. Create LLVM PassManager + // 2. Add optimization passes based on opt_level: + // - O0: No optimization + // - O1: Basic optimizations + // - O2: Standard optimizations (default) + // - O3: Aggressive optimizations + // 3. Run passes on the module + + throw std::runtime_error("Module optimization not yet implemented"); +} + +std::string WaveExecutionEngine::dump_llvm_ir() const { + if (!initialized_) { + throw std::runtime_error("ExecutionEngine not initialized"); + } + + // TODO: Implement LLVM IR dumping + // Steps: + // 1. Get the LLVM module from ExecutionEngine + // 2. Convert module to string using raw_string_ostream + // 3. Return the string representation + + return "LLVM IR dump not yet implemented"; +} + +void WaveExecutionEngine::cleanup() { + // TODO: Implement cleanup of LLVM resources + // Steps: + // 1. Clean up ExecutionEngine + // 2. Clean up LLVM module + // 3. Clean up LLVM context + // 4. Clean up MLIR context (if used) + + initialized_ = false; +} + +void initialize_llvm_mlir() { + // TODO: Implement LLVM and MLIR initialization + // Steps: + // 1. Initialize LLVM targets: + // - InitializeNativeTarget() + // - InitializeNativeTargetAsmPrinter() + // - InitializeNativeTargetAsmParser() + // 2. Register MLIR dialects (if needed): + // - registerAllDialects() + // - registerLLVMDialectTranslation() + // 3. Initialize LLVM passes (if needed) + + throw std::runtime_error("LLVM/MLIR initialization not yet implemented"); +} + +} // namespace wave diff --git a/wave_lang/kernel/wave/execution_engine/execution_engine.h b/wave_lang/kernel/wave/execution_engine/execution_engine.h new file mode 100644 index 000000000..0905e4b27 --- /dev/null +++ b/wave_lang/kernel/wave/execution_engine/execution_engine.h @@ -0,0 +1,96 @@ +// Copyright 2025 The IREE Authors +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef WAVE_EXECUTION_ENGINE_H +#define WAVE_EXECUTION_ENGINE_H + +#include +#include +#include +#include + +// Forward declarations for LLVM types (to be implemented) +// These will be replaced with actual LLVM includes when implementing +namespace llvm { + class Module; + class ExecutionEngine; + class LLVMContext; +} + +namespace mlir { + class MLIRContext; + class ModuleOp; +} + +namespace wave { + +/// LLVM ExecutionEngine wrapper for Wave JIT compilation +class WaveExecutionEngine { +public: + WaveExecutionEngine(); + ~WaveExecutionEngine(); + + // Disable copy and move operations + WaveExecutionEngine(const WaveExecutionEngine&) = delete; + WaveExecutionEngine& operator=(const WaveExecutionEngine&) = delete; + WaveExecutionEngine(WaveExecutionEngine&&) = delete; + WaveExecutionEngine& operator=(WaveExecutionEngine&&) = delete; + + /// Initialize the execution engine with MLIR module + /// @param mlir_module_str MLIR module as a string + /// @throws std::runtime_error if initialization fails or already initialized + void initialize(const std::string& mlir_module_str); + + /// Load a pre-compiled LLVM IR module + /// @param ir_str LLVM IR as a string + /// @throws std::runtime_error if loading fails + void load_llvm_ir(const std::string& ir_str); + + /// Lookup and invoke a function by name + /// @param func_name Name of the function to invoke + /// @param args Vector of arguments as uint64_t values + /// @throws std::runtime_error if engine not initialized or function not found + void invoke(const std::string& func_name, const std::vector& args); + + /// Get pointer to a function by name + /// @param func_name Name of the function + /// @return Address of the function + /// @throws std::runtime_error if engine not initialized or function not found + uintptr_t get_function_address(const std::string& func_name); + + /// Check if engine is initialized + /// @return true if initialized, false otherwise + bool is_initialized() const; + + /// Optimize the module with given optimization level + /// @param opt_level Optimization level (0-3, default 2) + /// @throws std::runtime_error if engine not initialized + void optimize(int opt_level = 2); + + /// Dump the current LLVM IR module for debugging + /// @return LLVM IR as a string + /// @throws std::runtime_error if engine not initialized + std::string dump_llvm_ir() const; + +private: + void cleanup(); + + bool initialized_; + + // TODO: Add private members for: + // std::unique_ptr llvm_context_; + // std::unique_ptr llvm_module_; + // std::unique_ptr execution_engine_; + // std::unique_ptr mlir_context_; +}; + +/// Initialize LLVM and MLIR infrastructure +/// Must be called before creating any WaveExecutionEngine instances +/// @throws std::runtime_error if initialization fails +void initialize_llvm_mlir(); + +} // namespace wave + +#endif // WAVE_EXECUTION_ENGINE_H From 1cf001062a336be5c41ded54a33ef88562f11a46 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 9 Nov 2025 22:34:13 +0100 Subject: [PATCH 08/77] ExecutionEngine WIP Signed-off-by: Ivan Butygin --- .../wave/execution_engine/CMakeLists.txt | 68 ++- .../kernel/wave/execution_engine/bindings.cpp | 147 +++--- .../execution_engine/execution_engine.cpp | 447 ++++++++++++++---- .../wave/execution_engine/execution_engine.h | 171 ++++--- 4 files changed, 573 insertions(+), 260 deletions(-) diff --git a/wave_lang/kernel/wave/execution_engine/CMakeLists.txt b/wave_lang/kernel/wave/execution_engine/CMakeLists.txt index 56a77dcb3..a56c1a369 100644 --- a/wave_lang/kernel/wave/execution_engine/CMakeLists.txt +++ b/wave_lang/kernel/wave/execution_engine/CMakeLists.txt @@ -33,18 +33,19 @@ find_package(nanobind CONFIG REQUIRED) # Build the core parts of nanobind once nanobind_build_library(nanobind SHARED) -# TODO: Add LLVM and MLIR dependencies when implementing -# find_package(LLVM REQUIRED CONFIG) -# find_package(MLIR REQUIRED CONFIG) -# -# message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}") -# message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") -# message(STATUS "Found MLIR in: ${MLIR_DIR}") -# -# include_directories(${LLVM_INCLUDE_DIRS}) -# include_directories(${MLIR_INCLUDE_DIRS}) -# -# add_definitions(${LLVM_DEFINITIONS}) +# Find LLVM and MLIR +find_package(LLVM REQUIRED CONFIG) +find_package(MLIR REQUIRED CONFIG) + +message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}") +message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") +message(STATUS "Found MLIR in: ${MLIR_DIR}") + +include_directories(${LLVM_INCLUDE_DIRS}) +include_directories(${MLIR_INCLUDE_DIRS}) + +separate_arguments(LLVM_DEFINITIONS_LIST NATIVE_COMMAND ${LLVM_DEFINITIONS}) +add_definitions(${LLVM_DEFINITIONS_LIST}) # Compile an extension library add_library(wave_execution_engine MODULE @@ -58,20 +59,35 @@ target_include_directories(wave_execution_engine PRIVATE ${CMAKE_CURRENT_SOURCE_ # Link against nanobind target_link_libraries(wave_execution_engine PRIVATE nanobind) -# TODO: Link against LLVM and MLIR libraries when implementing -# target_link_libraries(wave_execution_engine PRIVATE -# MLIRIR -# MLIRParser -# MLIRExecutionEngine -# MLIRTargetLLVMIRExport -# MLIRLLVMDialect -# LLVMCore -# LLVMSupport -# LLVMExecutionEngine -# LLVMJIT -# LLVMOrcJIT -# LLVMRuntimeDyld -# ) +# Link against LLVM and MLIR libraries +target_link_libraries(wave_execution_engine PRIVATE + # MLIR libraries + MLIRIR + MLIRCAPIIR + MLIRSupport + MLIRTargetLLVMIRExport + + # LLVM libraries + LLVMCore + LLVMSupport + LLVMExecutionEngine + LLVMOrcJIT + LLVMRuntimeDyld + LLVMTarget + LLVMPasses + LLVMipo + LLVMTransformUtils + LLVMAnalysis + LLVMScalarOpts + LLVMInstCombine + LLVMAggressiveInstCombine + LLVMVectorize + LLVMMC + LLVMObject + LLVMBitWriter + LLVMBitReader + LLVMIRReader +) set_target_properties(wave_execution_engine PROPERTIES LINK_WHAT_YOU_USE TRUE) diff --git a/wave_lang/kernel/wave/execution_engine/bindings.cpp b/wave_lang/kernel/wave/execution_engine/bindings.cpp index 53b487a86..0d145e6be 100644 --- a/wave_lang/kernel/wave/execution_engine/bindings.cpp +++ b/wave_lang/kernel/wave/execution_engine/bindings.cpp @@ -5,76 +5,99 @@ #include "execution_engine.h" #include +#include #include -#include + +#include + +#include +#include +#include namespace nb = nanobind; +// Helper to convert llvm::Expected to Python (throw on error) +template +static T unwrapExpected(llvm::Expected expected, const char *context) { + if (!expected) { + std::string errorMessage; + llvm::raw_string_ostream os(errorMessage); + llvm::logAllUnhandledErrors(expected.takeError(), os); + throw std::runtime_error(std::string(context) + ": " + os.str()); + } + return std::move(*expected); +} + // Nanobind module definition for Python bindings NB_MODULE(wave_execution_engine, m) { m.doc() = "LLVM ExecutionEngine bindings for Wave JIT compilation"; - // Bind the WaveExecutionEngine class - nb::class_(m, "ExecutionEngine") - .def(nb::init<>(), - "Create a new WaveExecutionEngine instance") - .def("initialize", &wave::WaveExecutionEngine::initialize, - nb::arg("mlir_module_str"), - "Initialize the execution engine with an MLIR module string.\n\n" - "Args:\n" - " mlir_module_str: MLIR module as a string\n\n" - "Raises:\n" - " RuntimeError: If initialization fails or already initialized") - .def("load_llvm_ir", &wave::WaveExecutionEngine::load_llvm_ir, - nb::arg("ir_str"), - "Load a pre-compiled LLVM IR module.\n\n" - "Args:\n" - " ir_str: LLVM IR as a string\n\n" - "Raises:\n" - " RuntimeError: If loading fails") - .def("invoke", &wave::WaveExecutionEngine::invoke, - nb::arg("func_name"), nb::arg("args"), - "Invoke a function by name with the given arguments.\n\n" - "Args:\n" - " func_name: Name of the function to invoke\n" - " args: List of arguments as uint64_t values\n\n" - "Raises:\n" - " RuntimeError: If engine not initialized or function not found") - .def("get_function_address", &wave::WaveExecutionEngine::get_function_address, - nb::arg("func_name"), - "Get the address of a function by name.\n\n" + // Bind ExecutionEngineOptions + nb::class_(m, "ExecutionEngineOptions") + .def(nb::init<>(), "Create default ExecutionEngineOptions") + .def_rw("enable_object_cache", + &wave::ExecutionEngineOptions::enableObjectCache, + "Enable object cache for compiled code") + .def_rw("enable_gdb_notification_listener", + &wave::ExecutionEngineOptions::enableGDBNotificationListener, + "Enable GDB notification listener") + .def_rw("enable_perf_notification_listener", + &wave::ExecutionEngineOptions::enablePerfNotificationListener, + "Enable Perf notification listener"); + + // Bind ExecutionEngine class + nb::class_(m, "ExecutionEngine") + .def(nb::init(), nb::arg("options"), + "Create a new ExecutionEngine with the given options.\n\n" "Args:\n" - " func_name: Name of the function\n\n" - "Returns:\n" - " Address of the function as an integer\n\n" - "Raises:\n" - " RuntimeError: If engine not initialized or function not found") - .def("is_initialized", &wave::WaveExecutionEngine::is_initialized, - "Check if the execution engine is initialized.\n\n" - "Returns:\n" - " True if initialized, False otherwise") - .def("optimize", &wave::WaveExecutionEngine::optimize, - nb::arg("opt_level") = 2, - "Optimize the module with the given optimization level.\n\n" + " options: ExecutionEngineOptions to configure the engine") + .def( + "load_module", + [](wave::ExecutionEngine &self, MlirModule cModule) { + auto module = unwrap(cModule); + auto handle = unwrapExpected(self.loadModule(module), + "Failed to load module"); + return reinterpret_cast(handle); + }, + nb::arg("module"), + "Compile and load an MLIR module into the execution engine.\n\n" + "Args:\n" + " module: MLIR module (MlirModule from MLIR C API)\n\n" + "Returns:\n" + " Module handle as integer\n\n" + "Raises:\n" + " RuntimeError: If compilation or loading fails") + .def( + "release_module", + [](wave::ExecutionEngine &self, uintptr_t handle) { + self.releaseModule(reinterpret_cast(handle)); + }, + nb::arg("handle"), + "Release a loaded module from the execution engine.\n\n" + "Args:\n" + " handle: Module handle returned from load_module") + .def( + "lookup", + [](const wave::ExecutionEngine &self, uintptr_t handle, + const std::string &name) { + auto ptr = unwrapExpected( + self.lookup(reinterpret_cast(handle), name), + "Failed to lookup function"); + return reinterpret_cast(ptr); + }, + nb::arg("handle"), nb::arg("name"), + "Look up a function in a loaded module.\n\n" + "Args:\n" + " handle: Module handle returned from load_module\n" + " name: Name of the function to look up\n\n" + "Returns:\n" + " Function address as integer\n\n" + "Raises:\n" + " RuntimeError: If function lookup fails") + .def("dump_to_object_file", &wave::ExecutionEngine::dumpToObjectFile, + nb::arg("filename"), + "Dump compiled object code to a file.\n\n" + "Note: Object cache must be enabled in ExecutionEngineOptions.\n\n" "Args:\n" - " opt_level: Optimization level (0-3, default 2)\n" - " 0 = No optimization\n" - " 1 = Basic optimizations\n" - " 2 = Standard optimizations (default)\n" - " 3 = Aggressive optimizations\n\n" - "Raises:\n" - " RuntimeError: If engine not initialized or invalid opt_level") - .def("dump_llvm_ir", &wave::WaveExecutionEngine::dump_llvm_ir, - "Dump the current LLVM IR module as a string.\n\n" - "Returns:\n" - " LLVM IR as a string\n\n" - "Raises:\n" - " RuntimeError: If engine not initialized"); - - // Bind the global initialization function - m.def("initialize_llvm_mlir", &wave::initialize_llvm_mlir, - "Initialize LLVM and MLIR infrastructure.\n\n" - "Must be called before creating any WaveExecutionEngine instances.\n\n" - "Raises:\n" - " RuntimeError: If initialization fails"); + " filename: Path to output file"); } diff --git a/wave_lang/kernel/wave/execution_engine/execution_engine.cpp b/wave_lang/kernel/wave/execution_engine/execution_engine.cpp index c3727b68e..4a2fee007 100644 --- a/wave_lang/kernel/wave/execution_engine/execution_engine.cpp +++ b/wave_lang/kernel/wave/execution_engine/execution_engine.cpp @@ -4,139 +4,392 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #include "execution_engine.h" -#include -namespace wave { +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include -WaveExecutionEngine::WaveExecutionEngine() : initialized_(false) {} +#include +#include +#include -WaveExecutionEngine::~WaveExecutionEngine() { - if (initialized_) { - cleanup(); +#include +#include +#include +#include + +#define DEBUG_TYPE "wave-execution-engine" + +static llvm::OptimizationLevel mapToLevel(llvm::CodeGenOptLevel level) { + unsigned optimizeSize = 0; // TODO: unhardcode + + switch (level) { + default: + llvm_unreachable("Invalid optimization level!"); + + case llvm::CodeGenOptLevel::None: + return llvm::OptimizationLevel::O0; + + case llvm::CodeGenOptLevel::Less: + return llvm::OptimizationLevel::O1; + + case llvm::CodeGenOptLevel::Default: + switch (optimizeSize) { + default: + llvm_unreachable("Invalid optimization level for size!"); + + case 0: + return llvm::OptimizationLevel::O2; + + case 1: + return llvm::OptimizationLevel::Os; + + case 2: + return llvm::OptimizationLevel::Oz; + } + + case llvm::CodeGenOptLevel::Aggressive: + return llvm::OptimizationLevel::O3; } } -void WaveExecutionEngine::initialize(const std::string& mlir_module_str) { - if (initialized_) { - throw std::runtime_error("ExecutionEngine already initialized"); +static llvm::PipelineTuningOptions +getPipelineTuningOptions(llvm::CodeGenOptLevel optLevelVal) { + llvm::PipelineTuningOptions pto; + auto level = static_cast(optLevelVal); + + pto.LoopUnrolling = level > 0; + pto.LoopVectorization = level > 1; + pto.SLPVectorization = level > 1; + return pto; +} + +static void runOptimizationPasses(llvm::Module &M, llvm::TargetMachine &TM) { + llvm::CodeGenOptLevel optLevelVal = TM.getOptLevel(); + + llvm::LoopAnalysisManager lam; + llvm::FunctionAnalysisManager fam; + llvm::CGSCCAnalysisManager cgam; + llvm::ModuleAnalysisManager mam; + + llvm::PassInstrumentationCallbacks pic; + llvm::PrintPassOptions ppo; + ppo.Indent = false; + ppo.SkipAnalyses = false; + llvm::StandardInstrumentations si(M.getContext(), /*debugLogging*/ false, + /*verifyEach*/ true, ppo); + + si.registerCallbacks(pic, &mam); + + llvm::PassBuilder pb(&TM, getPipelineTuningOptions(optLevelVal)); + + llvm::ModulePassManager mpm; + + if (/*verify*/ true) { + pb.registerPipelineStartEPCallback( + [&](llvm::ModulePassManager &mpm, llvm::OptimizationLevel level) { + mpm.addPass(createModuleToFunctionPassAdaptor(llvm::VerifierPass())); + }); } - // TODO: Implement MLIR module parsing and LLVM ExecutionEngine creation - // Steps: - // 1. Parse MLIR module from string - // 2. Convert MLIR to LLVM IR - // 3. Create LLVM ExecutionEngine - // 4. JIT compile the module + // Register all the basic analyses with the managers. + pb.registerModuleAnalyses(mam); + pb.registerCGSCCAnalyses(cgam); + pb.registerFunctionAnalyses(fam); + pb.registerLoopAnalyses(lam); + pb.crossRegisterProxies(lam, fam, cgam, mam); - throw std::runtime_error("ExecutionEngine initialization not yet implemented"); -} + llvm::OptimizationLevel level = mapToLevel(optLevelVal); -void WaveExecutionEngine::load_llvm_ir(const std::string& ir_str) { - // TODO: Implement LLVM IR loading - // Steps: - // 1. Parse LLVM IR string - // 2. Create LLVM module from IR - // 3. Set up ExecutionEngine with the module - // 4. Finalize the module for execution + if (optLevelVal == llvm::CodeGenOptLevel::None) { + mpm = pb.buildO0DefaultPipeline(level); + } else { + mpm = pb.buildPerModuleDefaultPipeline(level); + } - throw std::runtime_error("LLVM IR loading not yet implemented"); + mpm.run(M, mam); } -void WaveExecutionEngine::invoke(const std::string& func_name, - const std::vector& args) { - if (!initialized_) { - throw std::runtime_error("ExecutionEngine not initialized"); +/// A simple object cache following Lang's LLJITWithObjectCache example. +class wave::ExecutionEngine::SimpleObjectCache : public llvm::ObjectCache { +public: + void notifyObjectCompiled(const llvm::Module *m, + llvm::MemoryBufferRef objBuffer) override { + cachedObjects[m->getModuleIdentifier()] = + llvm::MemoryBuffer::getMemBufferCopy(objBuffer.getBuffer(), + objBuffer.getBufferIdentifier()); } - // TODO: Implement function lookup and invocation - // Steps: - // 1. Look up function by name in ExecutionEngine - // 2. Get function pointer - // 3. Marshal arguments based on function signature - // 4. Invoke function with marshalled arguments - // 5. Return results (if any) + std::unique_ptr + getObject(const llvm::Module *m) override { + auto i = cachedObjects.find(m->getModuleIdentifier()); + if (i == cachedObjects.end()) { + LLVM_DEBUG(llvm::dbgs() << "No object for " << m->getModuleIdentifier() + << " in cache. Compiling.\n"); + return nullptr; + } + LLVM_DEBUG(llvm::dbgs() << "Object for " << m->getModuleIdentifier() + << " loaded from cache.\n"); + return llvm::MemoryBuffer::getMemBuffer(i->second->getMemBufferRef()); + } - throw std::runtime_error("Function invocation not yet implemented"); -} + /// Dump cached object to output file `filename`. + void dumpToObjectFile(llvm::StringRef outputFilename) { + // Set up the output file. + std::string errorMessage; + auto file = mlir::openOutputFile(outputFilename, &errorMessage); + if (!file) { + llvm::errs() << errorMessage << "\n"; + return; + } -uintptr_t WaveExecutionEngine::get_function_address(const std::string& func_name) { - if (!initialized_) { - throw std::runtime_error("ExecutionEngine not initialized"); + // Dump the object generated for a single module to the output file. + assert(cachedObjects.size() == 1 && "Expected only one object entry."); + auto &cachedObject = cachedObjects.begin()->second; + file->os() << cachedObject->getBuffer(); + file->keep(); } - // TODO: Implement function address lookup - // Steps: - // 1. Look up function by name in ExecutionEngine - // 2. Get function address using ExecutionEngine::getFunctionAddress() - // 3. Return the address as uintptr_t +private: + llvm::StringMap> cachedObjects; +}; - throw std::runtime_error("Function address lookup not yet implemented"); +/// Wrap a string into an llvm::StringError. +static llvm::Error makeStringError(const llvm::Twine &message) { + return llvm::make_error(message.str(), + llvm::inconvertibleErrorCode()); } -bool WaveExecutionEngine::is_initialized() const { - return initialized_; +// Setup LLVM target triple from the current machine. +static void setupModule(llvm::Module &M, llvm::TargetMachine &TM) { + M.setDataLayout(TM.createDataLayout()); + M.setTargetTriple(TM.getTargetTriple().normalize()); + for (auto &&func : M.functions()) { + if (!func.hasFnAttribute("target-cpu")) + func.addFnAttr("target-cpu", TM.getTargetCPU()); + + if (!func.hasFnAttribute("target-features")) { + auto featStr = TM.getTargetFeatureString(); + if (!featStr.empty()) + func.addFnAttr("target-features", featStr); + } + } } -void WaveExecutionEngine::optimize(int opt_level) { - if (!initialized_) { - throw std::runtime_error("ExecutionEngine not initialized"); +namespace { +class CustomCompiler : public llvm::orc::SimpleCompiler { +public: + using Transformer = std::function; + using AsmPrinter = std::function; + + CustomCompiler(Transformer t, AsmPrinter a, + std::unique_ptr TM, + llvm::ObjectCache *ObjCache = nullptr) + : SimpleCompiler(*TM, ObjCache), TM(std::move(TM)), + transformer(std::move(t)), printer(std::move(a)) {} + + llvm::Expected operator()(llvm::Module &M) override { + if (transformer) { + auto err = transformer(M); + if (err) + return err; + } + + setupModule(M, *TM); + runOptimizationPasses(M, *TM); + + if (printer) { + llvm::SmallVector buffer; + llvm::raw_svector_ostream os(buffer); + + llvm::legacy::PassManager PM; + if (TM->addPassesToEmitFile(PM, os, nullptr, + llvm::CodeGenFileType::AssemblyFile)) + return makeStringError("Target does not support Asm emission"); + + PM.run(M); + printer(llvm::StringRef(buffer.data(), buffer.size())); + } + + return llvm::orc::SimpleCompiler::operator()(M); } - if (opt_level < 0 || opt_level > 3) { - throw std::runtime_error("Invalid optimization level. Must be 0-3"); +private: + std::shared_ptr TM; + Transformer transformer; + AsmPrinter printer; +}; +} // namespace + +wave::ExecutionEngine::ExecutionEngine(const ExecutionEngineOptions &options) + : cache(options.enableObjectCache ? new SimpleObjectCache() : nullptr), + gdbListener(options.enableGDBNotificationListener + ? llvm::JITEventListener::createGDBRegistrationListener() + : nullptr), + perfListener(nullptr) { + if (options.enablePerfNotificationListener) { + if (auto *listener = llvm::JITEventListener::createPerfJITEventListener()) + perfListener = listener; + else if (auto *listener = + llvm::JITEventListener::createIntelJITEventListener()) + perfListener = listener; } - // TODO: Implement module optimization - // Steps: - // 1. Create LLVM PassManager - // 2. Add optimization passes based on opt_level: - // - O0: No optimization - // - O1: Basic optimizations - // - O2: Standard optimizations (default) - // - O3: Aggressive optimizations - // 3. Run passes on the module - - throw std::runtime_error("Module optimization not yet implemented"); + // Callback to create the object layer with symbol resolution to current + // process and dynamically linked libraries. + auto objectLinkingLayerCreator = [this](llvm::orc::ExecutionSession &session, + const llvm::Triple &targetTriple) { + auto objectLayer = + std::make_unique(session, []() { + return std::make_unique(); + }); + + // Register JIT event listeners if they are enabled. + if (gdbListener) + objectLayer->registerJITEventListener(*gdbListener); + if (perfListener) + objectLayer->registerJITEventListener(*perfListener); + + // COFF format binaries (Windows) need special handling to deal with + // exported symbol visibility. + // cf llvm/lib/ExecutionEngine/Orc/LLJIT.cpp LLJIT::createObjectLinkingLayer + if (targetTriple.isOSBinFormatCOFF()) { + objectLayer->setOverrideObjectFlagsWithResponsibilityFlags(true); + objectLayer->setAutoClaimResponsibilityForObjectSymbols(true); + } + + return objectLayer; + }; + + // Callback to inspect the cache and recompile on demand. This follows Lang's + // LLJITWithObjectCache example. + auto compileFunctionCreator = + [this, jitCodeGenOptLevel = options.jitCodeGenOptLevel, + transformer = options.lateTransformer, + asmPrinter = options.asmPrinter](llvm::orc::JITTargetMachineBuilder jtmb) + -> llvm::Expected< + std::unique_ptr> { + if (jitCodeGenOptLevel) + jtmb.setCodeGenOptLevel(*jitCodeGenOptLevel); + auto tm = jtmb.createTargetMachine(); + if (!tm) + return tm.takeError(); + return std::make_unique(transformer, asmPrinter, + std::move(*tm), cache.get()); + }; + + auto tmBuilder = + llvm::cantFail(llvm::orc::JITTargetMachineBuilder::detectHost()); + + // Create the LLJIT by calling the LLJITBuilder with 2 callbacks. + jit = cantFail(llvm::orc::LLJITBuilder() + .setCompileFunctionCreator(compileFunctionCreator) + .setObjectLinkingLayerCreator(objectLinkingLayerCreator) + .setJITTargetMachineBuilder(tmBuilder) + .create()); + + symbolMap = std::move(options.symbolMap); + transformer = std::move(options.transformer); } -std::string WaveExecutionEngine::dump_llvm_ir() const { - if (!initialized_) { - throw std::runtime_error("ExecutionEngine not initialized"); +wave::ExecutionEngine::~ExecutionEngine() {} + +llvm::Expected +wave::ExecutionEngine::loadModule(mlir::ModuleOp m) { + assert(m); + + std::unique_ptr ctx(new llvm::LLVMContext); + auto llvmModule = mlir::translateModuleToLLVMIR(m, *ctx); + if (!llvmModule) + return makeStringError("could not convert to LLVM IR"); + + // Add a ThreadSafemodule to the engine and return. + llvm::orc::ThreadSafeModule tsm(std::move(llvmModule), std::move(ctx)); + if (transformer) + cantFail(tsm.withModuleDo( + [this](llvm::Module &module) { return transformer(module); })); + + llvm::orc::JITDylib *dylib; + while (true) { + auto uniqueName = + (llvm::Twine("module") + llvm::Twine(uniqueNameCounter++)).str(); + if (jit->getJITDylibByName(uniqueName)) + continue; + + auto res = jit->createJITDylib(std::move(uniqueName)); + if (!res) + return res.takeError(); + + dylib = &res.get(); + break; } + assert(dylib); - // TODO: Implement LLVM IR dumping - // Steps: - // 1. Get the LLVM module from ExecutionEngine - // 2. Convert module to string using raw_string_ostream - // 3. Return the string representation + auto dataLayout = jit->getDataLayout(); + dylib->addGenerator( + cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess( + dataLayout.getGlobalPrefix()))); - return "LLVM IR dump not yet implemented"; -} + if (symbolMap) + cantFail( + dylib->define(absoluteSymbols(symbolMap(llvm::orc::MangleAndInterner( + dylib->getExecutionSession(), jit->getDataLayout()))))); -void WaveExecutionEngine::cleanup() { - // TODO: Implement cleanup of LLVM resources - // Steps: - // 1. Clean up ExecutionEngine - // 2. Clean up LLVM module - // 3. Clean up LLVM context - // 4. Clean up MLIR context (if used) + llvm::cantFail(jit->addIRModule(*dylib, std::move(tsm))); + llvm::cantFail(jit->initialize(*dylib)); + return static_cast(dylib); +} - initialized_ = false; +void wave::ExecutionEngine::releaseModule(ModuleHandle handle) { + assert(handle); + auto dylib = static_cast(handle); + llvm::cantFail(jit->deinitialize(*dylib)); + llvm::cantFail(jit->getExecutionSession().removeJITDylib(*dylib)); } -void initialize_llvm_mlir() { - // TODO: Implement LLVM and MLIR initialization - // Steps: - // 1. Initialize LLVM targets: - // - InitializeNativeTarget() - // - InitializeNativeTargetAsmPrinter() - // - InitializeNativeTargetAsmParser() - // 2. Register MLIR dialects (if needed): - // - registerAllDialects() - // - registerLLVMDialectTranslation() - // 3. Initialize LLVM passes (if needed) - - throw std::runtime_error("LLVM/MLIR initialization not yet implemented"); +llvm::Expected +wave::ExecutionEngine::lookup(wave::ExecutionEngine::ModuleHandle handle, + llvm::StringRef name) const { + assert(handle); + auto dylib = static_cast(handle); + auto expectedSymbol = jit->lookup(*dylib, name); + + // JIT lookup may return an Error referring to strings stored internally by + // the JIT. If the Error outlives the ExecutionEngine, it would want have a + // dangling reference, which is currently caught by an assertion inside JIT + // thanks to hand-rolled reference counting. Rewrap the error message into a + // string before returning. Alternatively, ORC JIT should consider copying + // the string into the error message. + if (!expectedSymbol) { + std::string errorMessage; + llvm::raw_string_ostream os(errorMessage); + llvm::handleAllErrors(expectedSymbol.takeError(), + [&os](llvm::ErrorInfoBase &ei) { ei.log(os); }); + return makeStringError(os.str()); + } + + if (void *fptr = expectedSymbol->toPtr()) + return fptr; + + return makeStringError("looked up function is null"); } -} // namespace wave +void wave::ExecutionEngine::dumpToObjectFile(llvm::StringRef filename) { + if (cache == nullptr) { + llvm::errs() << "cannot dump ExecutionEngine object code to file: " + "object cache is disabled\n"; + return; + } + cache->dumpToObjectFile(filename); +} diff --git a/wave_lang/kernel/wave/execution_engine/execution_engine.h b/wave_lang/kernel/wave/execution_engine/execution_engine.h index 0905e4b27..1f6bd10a7 100644 --- a/wave_lang/kernel/wave/execution_engine/execution_engine.h +++ b/wave_lang/kernel/wave/execution_engine/execution_engine.h @@ -3,94 +3,115 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#ifndef WAVE_EXECUTION_ENGINE_H -#define WAVE_EXECUTION_ENGINE_H +#pragma once -#include -#include -#include -#include +#include +#include +#include +#include -// Forward declarations for LLVM types (to be implemented) -// These will be replaced with actual LLVM includes when implementing -namespace llvm { - class Module; - class ExecutionEngine; - class LLVMContext; -} +#include +#include namespace mlir { - class MLIRContext; - class ModuleOp; +class ModuleOp; } +namespace llvm { +template class Expected; +class JITEventListener; + +namespace orc { +class LLJIT; +class MangleAndInterner; +} // namespace orc +} // namespace llvm + namespace wave { +struct ExecutionEngineOptions { + /// `jitCodeGenOptLevel`, when provided, is used as the optimization level for + /// target code generation. + std::optional jitCodeGenOptLevel = std::nullopt; + + /// If `enableObjectCache` is set, the JIT compiler will create one to store + /// the object generated for the given module. The contents of the cache can + /// be dumped to a file via the `dumpToObjectfile` method. + bool enableObjectCache = false; + + /// If enable `enableGDBNotificationListener` is set, the JIT compiler will + /// notify the llvm's global GDB notification listener. + bool enableGDBNotificationListener = true; + + /// If `enablePerfNotificationListener` is set, the JIT compiler will notify + /// the llvm's global Perf notification listener. + bool enablePerfNotificationListener = true; + + /// Register symbols with this ExecutionEngine. + std::function symbolMap; + + /// If `transformer` is provided, it will be called on the LLVM module during + /// JIT-compilation and can be used, e.g., for reporting or optimization. + std::function transformer; + + /// If `lateTransformer` is provided, it will be called on the LLVM module + /// just before final code generation and can be used, e.g., for reporting or + /// optimization. + std::function lateTransformer; + + /// If `asmPrinter` is provided, it will be called to print resulted assembly + /// just before final code generation. + std::function asmPrinter; +}; + +class ExecutionEngine { + class SimpleObjectCache; -/// LLVM ExecutionEngine wrapper for Wave JIT compilation -class WaveExecutionEngine { public: - WaveExecutionEngine(); - ~WaveExecutionEngine(); - - // Disable copy and move operations - WaveExecutionEngine(const WaveExecutionEngine&) = delete; - WaveExecutionEngine& operator=(const WaveExecutionEngine&) = delete; - WaveExecutionEngine(WaveExecutionEngine&&) = delete; - WaveExecutionEngine& operator=(WaveExecutionEngine&&) = delete; - - /// Initialize the execution engine with MLIR module - /// @param mlir_module_str MLIR module as a string - /// @throws std::runtime_error if initialization fails or already initialized - void initialize(const std::string& mlir_module_str); - - /// Load a pre-compiled LLVM IR module - /// @param ir_str LLVM IR as a string - /// @throws std::runtime_error if loading fails - void load_llvm_ir(const std::string& ir_str); - - /// Lookup and invoke a function by name - /// @param func_name Name of the function to invoke - /// @param args Vector of arguments as uint64_t values - /// @throws std::runtime_error if engine not initialized or function not found - void invoke(const std::string& func_name, const std::vector& args); - - /// Get pointer to a function by name - /// @param func_name Name of the function - /// @return Address of the function - /// @throws std::runtime_error if engine not initialized or function not found - uintptr_t get_function_address(const std::string& func_name); - - /// Check if engine is initialized - /// @return true if initialized, false otherwise - bool is_initialized() const; - - /// Optimize the module with given optimization level - /// @param opt_level Optimization level (0-3, default 2) - /// @throws std::runtime_error if engine not initialized - void optimize(int opt_level = 2); - - /// Dump the current LLVM IR module for debugging - /// @return LLVM IR as a string - /// @throws std::runtime_error if engine not initialized - std::string dump_llvm_ir() const; + using ModuleHandle = void *; + + ExecutionEngine(const ExecutionEngineOptions &options); + ~ExecutionEngine(); + + /// Compiles given module, adds it to execution engine and run its contructors + /// if any. + llvm::Expected loadModule(mlir::ModuleOp m); + + /// Runs module desctructors and removes it from execution engine. + void releaseModule(ModuleHandle handle); + + /// Looks up the original function with the given name and returns a + /// pointer to it. Propagates errors in case of failure. + llvm::Expected lookup(ModuleHandle handle, + llvm::StringRef name) const; + + /// Dump object code to output file `filename`. + void dumpToObjectFile(llvm::StringRef filename); private: - void cleanup(); + /// Ordering of llvmContext and jit is important for destruction purposes: the + /// jit must be destroyed before the context. + llvm::LLVMContext llvmContext; - bool initialized_; + /// Underlying LLJIT. + std::unique_ptr jit; - // TODO: Add private members for: - // std::unique_ptr llvm_context_; - // std::unique_ptr llvm_module_; - // std::unique_ptr execution_engine_; - // std::unique_ptr mlir_context_; -}; + /// Underlying cache. + std::unique_ptr cache; -/// Initialize LLVM and MLIR infrastructure -/// Must be called before creating any WaveExecutionEngine instances -/// @throws std::runtime_error if initialization fails -void initialize_llvm_mlir(); + /// GDB notification listener. + llvm::JITEventListener *gdbListener; -} // namespace wave + /// Perf notification listener. + llvm::JITEventListener *perfListener; + + /// Callback to get additional symbol definitions. + std::function symbolMap; + + /// If `transformer` is provided, it will be called on the LLVM module during + /// JIT-compilation and can be used, e.g., for reporting or optimization. + std::function transformer; -#endif // WAVE_EXECUTION_ENGINE_H + /// Id for unique module name generation. + int uniqueNameCounter = 0; +}; +} // namespace wave From ec2b19dfb60ba9805793111d864054f5071aa9d4 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 9 Nov 2025 23:01:58 +0100 Subject: [PATCH 09/77] Fix LLVM API compatibility in ExecutionEngine MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Update execution_engine.cpp to work with newer LLVM API: - Remove normalize() call in setTargetTriple (now expects Triple directly) - Update objectLinkingLayerCreator lambda signature (no longer takes Triple param) - Fix RTDyldObjectLinkingLayer constructor to use GetMemoryManagerFunction - Add llvm:: namespace prefix to cantFail calls 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Signed-off-by: Ivan Butygin --- .../execution_engine/execution_engine.cpp | 37 +++++++++++-------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/wave_lang/kernel/wave/execution_engine/execution_engine.cpp b/wave_lang/kernel/wave/execution_engine/execution_engine.cpp index 4a2fee007..b8cf5de49 100644 --- a/wave_lang/kernel/wave/execution_engine/execution_engine.cpp +++ b/wave_lang/kernel/wave/execution_engine/execution_engine.cpp @@ -174,7 +174,7 @@ static llvm::Error makeStringError(const llvm::Twine &message) { // Setup LLVM target triple from the current machine. static void setupModule(llvm::Module &M, llvm::TargetMachine &TM) { M.setDataLayout(TM.createDataLayout()); - M.setTargetTriple(TM.getTargetTriple().normalize()); + M.setTargetTriple(TM.getTargetTriple()); for (auto &&func : M.functions()) { if (!func.hasFnAttribute("target-cpu")) func.addFnAttr("target-cpu", TM.getTargetCPU()); @@ -246,14 +246,23 @@ wave::ExecutionEngine::ExecutionEngine(const ExecutionEngineOptions &options) perfListener = listener; } + auto tmBuilder = + llvm::cantFail(llvm::orc::JITTargetMachineBuilder::detectHost()); + + // Get the target triple from the builder + auto targetTriple = tmBuilder.getTargetTriple(); + // Callback to create the object layer with symbol resolution to current // process and dynamically linked libraries. - auto objectLinkingLayerCreator = [this](llvm::orc::ExecutionSession &session, - const llvm::Triple &targetTriple) { - auto objectLayer = - std::make_unique(session, []() { - return std::make_unique(); - }); + auto objectLinkingLayerCreator = + [this, targetTriple](llvm::orc::ExecutionSession &session) + -> llvm::Expected> { + auto GetMemMgr = [](const llvm::MemoryBuffer &) { + return std::make_unique(); + }; + + auto objectLayer = std::make_unique( + session, GetMemMgr); // Register JIT event listeners if they are enabled. if (gdbListener) @@ -289,15 +298,13 @@ wave::ExecutionEngine::ExecutionEngine(const ExecutionEngineOptions &options) std::move(*tm), cache.get()); }; - auto tmBuilder = - llvm::cantFail(llvm::orc::JITTargetMachineBuilder::detectHost()); - // Create the LLJIT by calling the LLJITBuilder with 2 callbacks. - jit = cantFail(llvm::orc::LLJITBuilder() - .setCompileFunctionCreator(compileFunctionCreator) - .setObjectLinkingLayerCreator(objectLinkingLayerCreator) - .setJITTargetMachineBuilder(tmBuilder) - .create()); + jit = llvm::cantFail( + llvm::orc::LLJITBuilder() + .setCompileFunctionCreator(compileFunctionCreator) + .setObjectLinkingLayerCreator(objectLinkingLayerCreator) + .setJITTargetMachineBuilder(tmBuilder) + .create()); symbolMap = std::move(options.symbolMap); transformer = std::move(options.transformer); From 7e4c6c8b7b321ac5f610455198c47ff4e52e70a4 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 10 Nov 2025 01:01:40 +0100 Subject: [PATCH 10/77] new wrapper Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/compile.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/wave_lang/kernel/wave/compile.py b/wave_lang/kernel/wave/compile.py index ccce9e4e0..4093fbd6d 100644 --- a/wave_lang/kernel/wave/compile.py +++ b/wave_lang/kernel/wave/compile.py @@ -264,6 +264,18 @@ def __call__(self, *args, **kwargs): return invoke_with_profile(self.options, self.invoke, *args, **kwargs) +class WaveKernel2: + def __init__(self, options: WaveCompileOptions, module: Module | bytes): + self.options = options + self.module = module + + def __call__(self, *args, **kwargs): + return self.invoke(*args, **kwargs) + + def invoke(self, *args, **kwargs): + raise NotImplementedError("invoke is not implemented for WaveKernel2") + + def wave_compile( options: WaveCompileOptions, kernel: "LaunchableWave", @@ -446,6 +458,12 @@ def get_binary_path(): asm = _generate_asm_code(mb, options) if options.backend == "asm" and not options.compile_to_asm: _compile_asm_to_binary(asm, options) + if options.use_water_pipeline: + from .water import water_lowering_pipeline + + module = water_lowering_pipeline(mb.module_op, options.target) + return WaveKernel2(options, module) + elif not options.compile_to_mlir: # LLVM flow: only compile to VMFB when not in MLIR-only mode compiled_wave_vmfb = compile_to_vmfb(asm, options) From 8b34f6526eaf7aca48eb89abf0153b7363b07b9a Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 10 Nov 2025 01:40:22 +0100 Subject: [PATCH 11/77] execution engine Signed-off-by: Ivan Butygin --- .../wave/execution_engine/CMakeLists.txt | 61 +++++++++++-------- .../wave_execution_engine.lds | 10 +++ 2 files changed, 47 insertions(+), 24 deletions(-) create mode 100644 wave_lang/kernel/wave/execution_engine/wave_execution_engine.lds diff --git a/wave_lang/kernel/wave/execution_engine/CMakeLists.txt b/wave_lang/kernel/wave/execution_engine/CMakeLists.txt index a56c1a369..5d66168c0 100644 --- a/wave_lang/kernel/wave/execution_engine/CMakeLists.txt +++ b/wave_lang/kernel/wave/execution_engine/CMakeLists.txt @@ -23,6 +23,8 @@ if (NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES) set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo") endif() +find_package(Python COMPONENTS Interpreter Development REQUIRED) + # Detect the installed nanobind package and import it into CMake execute_process( COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir @@ -31,7 +33,7 @@ execute_process( find_package(nanobind CONFIG REQUIRED) # Build the core parts of nanobind once -nanobind_build_library(nanobind SHARED) +nanobind_build_library(nanobind STATIC) # Find LLVM and MLIR find_package(LLVM REQUIRED CONFIG) @@ -53,44 +55,55 @@ add_library(wave_execution_engine MODULE bindings.cpp ) +# Disable RTTI for execution_engine.cpp to avoid typeinfo symbol dependencies +# Only for GCC/Clang on Unix-like systems +if(UNIX AND (CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang")) + set_source_files_properties(execution_engine.cpp PROPERTIES COMPILE_FLAGS "-fno-rtti") +endif() + # Include current directory for header files target_include_directories(wave_execution_engine PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) # Link against nanobind target_link_libraries(wave_execution_engine PRIVATE nanobind) +get_property(mlir_dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(mlir_conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) +get_property(mlir_extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) + # Link against LLVM and MLIR libraries target_link_libraries(wave_execution_engine PRIVATE - # MLIR libraries - MLIRIR - MLIRCAPIIR - MLIRSupport - MLIRTargetLLVMIRExport + Python::Python + Python::Module - # LLVM libraries - LLVMCore - LLVMSupport - LLVMExecutionEngine + LLVM${LLVM_NATIVE_ARCH}AsmParser + LLVM${LLVM_NATIVE_ARCH}CodeGen + LLVM${LLVM_NATIVE_ARCH}Desc LLVMOrcJIT - LLVMRuntimeDyld LLVMTarget - LLVMPasses - LLVMipo - LLVMTransformUtils - LLVMAnalysis - LLVMScalarOpts - LLVMInstCombine - LLVMAggressiveInstCombine - LLVMVectorize - LLVMMC - LLVMObject - LLVMBitWriter - LLVMBitReader - LLVMIRReader + MLIRCAPIDebug + MLIRCAPIIR + MLIRCAPIInterfaces + MLIRCAPITransforms + MLIRPass + MLIRToLLVMIRTranslationRegistration + ${mlir_dialect_libs} + ${mlir_conversion_libs} + ${mlir_extension_libs} ) set_target_properties(wave_execution_engine PROPERTIES LINK_WHAT_YOU_USE TRUE) +# Force all symbols to be resolved at compile time (Linux/Unix only) +if(UNIX AND NOT APPLE) + target_link_options(wave_execution_engine PRIVATE + -Wl,--no-undefined + -Wl,-z,defs + # Use linker script to hide all symbols except Python module entry point + -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/wave_execution_engine.lds + ) +endif() + # Enable size optimizations nanobind_opt_size(wave_execution_engine) diff --git a/wave_lang/kernel/wave/execution_engine/wave_execution_engine.lds b/wave_lang/kernel/wave/execution_engine/wave_execution_engine.lds new file mode 100644 index 000000000..5c9d44621 --- /dev/null +++ b/wave_lang/kernel/wave/execution_engine/wave_execution_engine.lds @@ -0,0 +1,10 @@ +# Linker version script for wave_execution_engine +# Hide all symbols except the Python module entry point +{ + global: + # Python module initialization function + PyInit_wave_execution_engine; + local: + # Hide all other symbols + *; +}; From 8038c3547eae621c39e8eba2d76cd8578b6ed5ad Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 11 Nov 2025 13:28:32 +0100 Subject: [PATCH 12/77] init llvm target Signed-off-by: Ivan Butygin --- .../wave/execution_engine/execution_engine.cpp | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/wave_lang/kernel/wave/execution_engine/execution_engine.cpp b/wave_lang/kernel/wave/execution_engine/execution_engine.cpp index b8cf5de49..ef6d3fcc9 100644 --- a/wave_lang/kernel/wave/execution_engine/execution_engine.cpp +++ b/wave_lang/kernel/wave/execution_engine/execution_engine.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -29,8 +30,19 @@ #include #include +#include + #define DEBUG_TYPE "wave-execution-engine" +// Ensure LLVM native target is initialized only once +static std::once_flag llvmInitFlag; + +static void initializeLLVMTarget() { + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + llvm::InitializeNativeTargetAsmParser(); +} + static llvm::OptimizationLevel mapToLevel(llvm::CodeGenOptLevel level) { unsigned optimizeSize = 0; // TODO: unhardcode @@ -246,6 +258,9 @@ wave::ExecutionEngine::ExecutionEngine(const ExecutionEngineOptions &options) perfListener = listener; } + // Initialize LLVM native target (only once per process) + std::call_once(llvmInitFlag, initializeLLVMTarget); + auto tmBuilder = llvm::cantFail(llvm::orc::JITTargetMachineBuilder::detectHost()); From 2a4c7fd56cef96327c3d44f97fab3c2c9537f5fc Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 11 Nov 2025 13:38:51 +0100 Subject: [PATCH 13/77] execution_engine python wrapper Signed-off-by: Ivan Butygin --- .../wave/test_execution_engine_wrapper.py | 181 ++++++++++++++++++ .../kernel/wave/execution_engine/__init__.py | 37 ++++ .../kernel/wave/execution_engine/bindings.cpp | 3 +- .../wave/execution_engine/execution_engine.py | 127 ++++++++++++ 4 files changed, 347 insertions(+), 1 deletion(-) create mode 100644 tests/kernel/wave/test_execution_engine_wrapper.py create mode 100644 wave_lang/kernel/wave/execution_engine/__init__.py create mode 100644 wave_lang/kernel/wave/execution_engine/execution_engine.py diff --git a/tests/kernel/wave/test_execution_engine_wrapper.py b/tests/kernel/wave/test_execution_engine_wrapper.py new file mode 100644 index 000000000..e8af0b5f5 --- /dev/null +++ b/tests/kernel/wave/test_execution_engine_wrapper.py @@ -0,0 +1,181 @@ +# Copyright 2025 The IREE Authors +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +""" +Tests for the ExecutionEngine wrapper with weak reference caching. +""" + +import gc +import os +import pytest +import weakref + +try: + from wave_lang.kernel.wave.execution_engine import ( + get_execution_engine, + clear_engine_cache, + is_engine_cached, + ) + + EXECUTION_ENGINE_AVAILABLE = True +except ImportError: + EXECUTION_ENGINE_AVAILABLE = False + +pytestmark = pytest.mark.skipif( + not EXECUTION_ENGINE_AVAILABLE, + reason="ExecutionEngine not available (C++ extension not built)", +) + + +ENV_VARS = [ + "WAVE_ENABLE_OBJECT_CACHE", + "WAVE_ENABLE_GDB_LISTENER", + "WAVE_ENABLE_PERF_LISTENER", +] + + +@pytest.fixture +def clean_env(): + """Fixture to clean cache and environment variables before/after each test.""" + clear_engine_cache() + + # Save original env vars + orig_env = {} + for key in ENV_VARS: + orig_env[key] = os.environ.get(key) + if key in os.environ: + del os.environ[key] + + yield + + # Restore original env vars + for key, value in orig_env.items(): + if value is not None: + os.environ[key] = value + elif key in os.environ: + del os.environ[key] + + clear_engine_cache() + gc.collect() + + +def test_basic_creation(clean_env): + """Test basic creation of execution engine.""" + engine = get_execution_engine() + assert engine is not None + assert is_engine_cached() + + +def test_cache_returns_same_instance(clean_env): + """Test that cache returns the same instance.""" + engine1 = get_execution_engine() + engine2 = get_execution_engine() + + # Should be the exact same object + assert engine1 is engine2 + + +def test_weak_reference_cleanup(clean_env): + """Test that engine is cleaned up when references are released.""" + # Create engine and get weak reference + engine = get_execution_engine() + weak_ref = weakref.ref(engine) + + # Verify it's cached + assert is_engine_cached() + + # Delete strong reference + del engine + gc.collect() + + # Weak reference should be dead + assert weak_ref() is None + + # Cache should report no engine + assert not is_engine_cached() + + +def test_clear_cache(clean_env): + """Test that clear_cache removes cached engine.""" + engine = get_execution_engine() + assert is_engine_cached() + + # Clear cache + clear_engine_cache() + + # Cache should be empty + assert not is_engine_cached() + + # Engine should still be alive (we have strong reference) + assert engine is not None + + +def test_env_var_object_cache(clean_env): + """Test WAVE_ENABLE_OBJECT_CACHE environment variable.""" + os.environ["WAVE_ENABLE_OBJECT_CACHE"] = "1" + + engine = get_execution_engine() + assert engine is not None + # We can't directly test if object cache is enabled without + # accessing internal state, but we verify the engine was created + + +def test_env_var_gdb_listener(clean_env): + """Test WAVE_ENABLE_GDB_LISTENER environment variable.""" + os.environ["WAVE_ENABLE_GDB_LISTENER"] = "1" + + engine = get_execution_engine() + assert engine is not None + + +def test_env_var_perf_listener(clean_env): + """Test WAVE_ENABLE_PERF_LISTENER environment variable.""" + os.environ["WAVE_ENABLE_PERF_LISTENER"] = "1" + + engine = get_execution_engine() + assert engine is not None + + +def test_all_env_vars(clean_env): + """Test all environment variables together.""" + os.environ["WAVE_ENABLE_OBJECT_CACHE"] = "1" + os.environ["WAVE_ENABLE_GDB_LISTENER"] = "1" + os.environ["WAVE_ENABLE_PERF_LISTENER"] = "1" + + engine = get_execution_engine() + assert engine is not None + assert is_engine_cached() + + +def test_multiple_references(clean_env): + """Test that multiple references to the same engine work correctly.""" + engine1 = get_execution_engine() + engine2 = get_execution_engine() + engine3 = get_execution_engine() + + # All should be the same object + assert engine1 is engine2 + assert engine2 is engine3 + + +def test_engine_interface(clean_env): + """Test that the engine provides the correct interface.""" + engine = get_execution_engine() + + # Check that required methods exist + assert hasattr(engine, "load_module") + assert hasattr(engine, "release_module") + assert hasattr(engine, "lookup") + assert hasattr(engine, "dump_to_object_file") + + # Check that methods are callable + assert callable(engine.load_module) + assert callable(engine.release_module) + assert callable(engine.lookup) + assert callable(engine.dump_to_object_file) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/wave_lang/kernel/wave/execution_engine/__init__.py b/wave_lang/kernel/wave/execution_engine/__init__.py new file mode 100644 index 000000000..1378bf122 --- /dev/null +++ b/wave_lang/kernel/wave/execution_engine/__init__.py @@ -0,0 +1,37 @@ +# Copyright 2025 The IREE Authors +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +""" +ExecutionEngine module for Wave JIT compilation. + +This module provides both low-level C++ bindings and a high-level Python wrapper +for the LLVM-based execution engine used by Wave. The wrapper caches a single +ExecutionEngine instance using weak references and configures it via environment +variables. +""" + +# Import C++ bindings (may not be available if not built yet) +try: + from wave_execution_engine import ExecutionEngine, ExecutionEngineOptions +except ImportError: + ExecutionEngine = None + ExecutionEngineOptions = None + +# Import Python wrapper with caching +from .execution_engine import ( + get_execution_engine, + clear_engine_cache, + is_engine_cached, +) + +__all__ = [ + # C++ bindings + "ExecutionEngine", + "ExecutionEngineOptions", + # Python wrapper + "get_execution_engine", + "clear_engine_cache", + "is_engine_cached", +] diff --git a/wave_lang/kernel/wave/execution_engine/bindings.cpp b/wave_lang/kernel/wave/execution_engine/bindings.cpp index 0d145e6be..60ef7c660 100644 --- a/wave_lang/kernel/wave/execution_engine/bindings.cpp +++ b/wave_lang/kernel/wave/execution_engine/bindings.cpp @@ -46,7 +46,8 @@ NB_MODULE(wave_execution_engine, m) { "Enable Perf notification listener"); // Bind ExecutionEngine class - nb::class_(m, "ExecutionEngine") + nb::class_(m, "ExecutionEngine", + nb::is_weak_referenceable()) .def(nb::init(), nb::arg("options"), "Create a new ExecutionEngine with the given options.\n\n" "Args:\n" diff --git a/wave_lang/kernel/wave/execution_engine/execution_engine.py b/wave_lang/kernel/wave/execution_engine/execution_engine.py new file mode 100644 index 000000000..0a20f7896 --- /dev/null +++ b/wave_lang/kernel/wave/execution_engine/execution_engine.py @@ -0,0 +1,127 @@ +# Copyright 2025 The IREE Authors +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +""" +Python wrapper for ExecutionEngine with weak reference caching. + +This module provides a simple singleton wrapper around the native ExecutionEngine +that caches a single instance using weak references. Options are configured via +environment variables. +""" + +import os +import weakref +from typing import Optional + +try: + from wave_execution_engine import ExecutionEngine, ExecutionEngineOptions +except ImportError: + # Allow import to succeed even if C++ module not built yet + ExecutionEngine = None + ExecutionEngineOptions = None + + +# Global weak reference to the cached ExecutionEngine instance +_cached_engine: Optional[weakref.ref] = None + + +def _create_options_from_env() -> "ExecutionEngineOptions": + """ + Create ExecutionEngineOptions from environment variables. + + Environment Variables: + WAVE_ENABLE_OBJECT_CACHE: Enable object cache (default: 0) + WAVE_ENABLE_GDB_LISTENER: Enable GDB notification listener (default: 0) + WAVE_ENABLE_PERF_LISTENER: Enable Perf notification listener (default: 0) + + Returns: + ExecutionEngineOptions configured from environment + """ + if ExecutionEngineOptions is None: + raise RuntimeError( + "wave_execution_engine module not available. " + "Ensure the C++ extension is built and installed." + ) + + options = ExecutionEngineOptions() + + # Read options from environment variables + def _env_enabled(var: str, default: str = "0") -> bool: + return bool(int(os.environ.get(var, default))) + + options.enable_object_cache = _env_enabled("WAVE_ENABLE_OBJECT_CACHE") + options.enable_gdb_notification_listener = _env_enabled("WAVE_ENABLE_GDB_LISTENER") + options.enable_perf_notification_listener = _env_enabled( + "WAVE_ENABLE_PERF_LISTENER" + ) + + return options + + +def get_execution_engine() -> "ExecutionEngine": + """ + Get or create the global ExecutionEngine instance. + + This function maintains a single cached ExecutionEngine instance using + weak references. If the cached instance has been garbage collected, a + new one is created. Options are configured via environment variables. + + Returns: + ExecutionEngine instance + + Example: + >>> engine = get_execution_engine() + >>> handle = engine.load_module(my_mlir_module) + >>> func_ptr = engine.lookup(handle, "my_function") + >>> engine.release_module(handle) + """ + global _cached_engine + + # Try to get cached instance + if _cached_engine is not None: + engine = _cached_engine() + if engine is not None: + return engine + + # Create new instance with options from environment + options = _create_options_from_env() + engine = ExecutionEngine(options) + + # Cache using weak reference + _cached_engine = weakref.ref(engine) + + return engine + + +def clear_engine_cache(): + """ + Clear the cached execution engine instance. + + Note: The engine will only be destroyed if there are no other + references to it. If you're holding a reference, the engine + will remain alive until that reference is released. + """ + global _cached_engine + _cached_engine = None + + +def is_engine_cached() -> bool: + """ + Check if an execution engine is currently cached. + + Returns: + True if an engine is cached and still alive, False otherwise + """ + global _cached_engine + if _cached_engine is None: + return False + return _cached_engine() is not None + + +__all__ = [ + "get_execution_engine", + "clear_engine_cache", + "is_engine_cached", +] From 1bfa8d8cc4669d07636bdacdd1aac0c5e8d53cc9 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 11 Nov 2025 13:53:57 +0100 Subject: [PATCH 14/77] load module from bytecode Signed-off-by: Ivan Butygin --- .../kernel/wave/execution_engine/bindings.cpp | 21 +++++++++++++ .../execution_engine/execution_engine.cpp | 30 +++++++++++++++++++ .../wave/execution_engine/execution_engine.h | 12 +++++++- 3 files changed, 62 insertions(+), 1 deletion(-) diff --git a/wave_lang/kernel/wave/execution_engine/bindings.cpp b/wave_lang/kernel/wave/execution_engine/bindings.cpp index 60ef7c660..6f2c4c33d 100644 --- a/wave_lang/kernel/wave/execution_engine/bindings.cpp +++ b/wave_lang/kernel/wave/execution_engine/bindings.cpp @@ -9,10 +9,14 @@ #include #include +#include #include +#include #include #include +#include +#include namespace nb = nanobind; @@ -68,6 +72,23 @@ NB_MODULE(wave_execution_engine, m) { " Module handle as integer\n\n" "Raises:\n" " RuntimeError: If compilation or loading fails") + .def( + "load_module_from_bytecode", + [](wave::ExecutionEngine &self, nb::bytes bytecode) { + // Convert Python bytes to ArrayRef + llvm::ArrayRef data(bytecode.c_str(), bytecode.size()); + auto handle = unwrapExpected(self.loadModuleFromBytecode(data), + "Failed to load module from bytecode"); + return reinterpret_cast(handle); + }, + nb::arg("bytecode"), + "Deserialize MLIR bytecode and load it into the execution engine.\n\n" + "Args:\n" + " bytecode: MLIR module serialized as bytecode (bytes)\n\n" + "Returns:\n" + " Module handle as integer\n\n" + "Raises:\n" + " RuntimeError: If deserialization, compilation or loading fails") .def( "release_module", [](wave::ExecutionEngine &self, uintptr_t handle) { diff --git a/wave_lang/kernel/wave/execution_engine/execution_engine.cpp b/wave_lang/kernel/wave/execution_engine/execution_engine.cpp index ef6d3fcc9..a13536f8e 100644 --- a/wave_lang/kernel/wave/execution_engine/execution_engine.cpp +++ b/wave_lang/kernel/wave/execution_engine/execution_engine.cpp @@ -21,7 +21,12 @@ #include #include +#include #include +#include +#include +#include +#include #include #include @@ -373,6 +378,31 @@ wave::ExecutionEngine::loadModule(mlir::ModuleOp m) { return static_cast(dylib); } +llvm::Expected +wave::ExecutionEngine::loadModuleFromBytecode(llvm::ArrayRef bytecode) { + // Create MLIR context on demand if not already created + if (!mlirContext) + mlirContext = std::make_unique(); + + // Create memory buffer from bytecode + auto memoryBuffer = llvm::MemoryBuffer::getMemBuffer( + llvm::StringRef(bytecode.data(), bytecode.size()), + /*BufferName=*/"bytecode", + /*RequiresNullTerminator=*/false); + + // Deserialize MLIR module from bytecode + mlir::OwningOpRef module( + mlir::ModuleOp::create(mlir::UnknownLoc::get(mlirContext.get()))); + + mlir::ParserConfig config(mlirContext.get()); + if (mlir::failed(mlir::readBytecodeFile(memoryBuffer->getMemBufferRef(), + module->getBody(), config))) + return makeStringError("Failed to deserialize MLIR bytecode"); + + // Load the deserialized module + return loadModule(module.get()); +} + void wave::ExecutionEngine::releaseModule(ModuleHandle handle) { assert(handle); auto dylib = static_cast(handle); diff --git a/wave_lang/kernel/wave/execution_engine/execution_engine.h b/wave_lang/kernel/wave/execution_engine/execution_engine.h index 1f6bd10a7..a196d4db4 100644 --- a/wave_lang/kernel/wave/execution_engine/execution_engine.h +++ b/wave_lang/kernel/wave/execution_engine/execution_engine.h @@ -5,6 +5,7 @@ #pragma once +#include #include #include #include @@ -14,8 +15,9 @@ #include namespace mlir { +class MLIRContext; class ModuleOp; -} +} // namespace mlir namespace llvm { template class Expected; @@ -76,6 +78,11 @@ class ExecutionEngine { /// if any. llvm::Expected loadModule(mlir::ModuleOp m); + /// Deserializes MLIR bytecode from a memory buffer, compiles it, and loads + /// it into the execution engine. + llvm::Expected + loadModuleFromBytecode(llvm::ArrayRef bytecode); + /// Runs module desctructors and removes it from execution engine. void releaseModule(ModuleHandle handle); @@ -88,6 +95,9 @@ class ExecutionEngine { void dumpToObjectFile(llvm::StringRef filename); private: + /// MLIR context for deserializing bytecode. + std::unique_ptr mlirContext; + /// Ordering of llvmContext and jit is important for destruction purposes: the /// jit must be destroyed before the context. llvm::LLVMContext llvmContext; From b3b09f5c1515b93c55e73788800aeaa85e9b8a8d Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 11 Nov 2025 15:24:09 +0100 Subject: [PATCH 15/77] bytecode loading Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/compile.py | 39 ++++++++++++++++++- .../wave/execution_engine/CMakeLists.txt | 1 + .../execution_engine/execution_engine.cpp | 8 +++- 3 files changed, 44 insertions(+), 4 deletions(-) diff --git a/wave_lang/kernel/wave/compile.py b/wave_lang/kernel/wave/compile.py index 4093fbd6d..924ba4683 100644 --- a/wave_lang/kernel/wave/compile.py +++ b/wave_lang/kernel/wave/compile.py @@ -267,13 +267,48 @@ def __call__(self, *args, **kwargs): class WaveKernel2: def __init__(self, options: WaveCompileOptions, module: Module | bytes): self.options = options - self.module = module + + self._engine = None + self._module_handle = None + + # Serialize MLIR module to bytecode if needed + if isinstance(module, bytes): + self.bytecode = module + else: + # Serialize the MLIR module to bytecode + import io + + bytecode_io = io.BytesIO() + module.operation.write_bytecode(bytecode_io) + self.bytecode = bytecode_io.getvalue() + + # Load module eagerly + from wave_lang.kernel.wave.execution_engine import get_execution_engine + + self._engine = get_execution_engine() + self._module_handle = self._engine.load_module_from_bytecode(self.bytecode) def __call__(self, *args, **kwargs): return self.invoke(*args, **kwargs) def invoke(self, *args, **kwargs): - raise NotImplementedError("invoke is not implemented for WaveKernel2") + """ + Invokes the wave kernel with the given arguments using the ExecutionEngine. + """ + # TODO: Implement argument marshalling and function invocation + # This will need to: + # 1. Convert Python/PyTorch arguments to C-compatible pointers + # 2. Look up the kernel function in the loaded module + # 3. Call the function with ctypes + # 4. Handle return values + raise NotImplementedError( + "WaveKernel2.invoke: argument marshalling not yet implemented" + ) + + def __del__(self): + """Clean up the loaded module when the kernel is destroyed.""" + if self._module_handle is not None and self._engine is not None: + self._engine.release_module(self._module_handle) def wave_compile( diff --git a/wave_lang/kernel/wave/execution_engine/CMakeLists.txt b/wave_lang/kernel/wave/execution_engine/CMakeLists.txt index 5d66168c0..78ff8ce7a 100644 --- a/wave_lang/kernel/wave/execution_engine/CMakeLists.txt +++ b/wave_lang/kernel/wave/execution_engine/CMakeLists.txt @@ -86,6 +86,7 @@ target_link_libraries(wave_execution_engine PRIVATE MLIRCAPIInterfaces MLIRCAPITransforms MLIRPass + MLIRRegisterAllDialects MLIRToLLVMIRTranslationRegistration ${mlir_dialect_libs} ${mlir_conversion_libs} diff --git a/wave_lang/kernel/wave/execution_engine/execution_engine.cpp b/wave_lang/kernel/wave/execution_engine/execution_engine.cpp index a13536f8e..432d6b58d 100644 --- a/wave_lang/kernel/wave/execution_engine/execution_engine.cpp +++ b/wave_lang/kernel/wave/execution_engine/execution_engine.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -381,8 +382,11 @@ wave::ExecutionEngine::loadModule(mlir::ModuleOp m) { llvm::Expected wave::ExecutionEngine::loadModuleFromBytecode(llvm::ArrayRef bytecode) { // Create MLIR context on demand if not already created - if (!mlirContext) - mlirContext = std::make_unique(); + if (!mlirContext) { + mlir::DialectRegistry registry; + mlir::registerAllDialects(registry); + mlirContext = std::make_unique(registry); + } // Create memory buffer from bytecode auto memoryBuffer = llvm::MemoryBuffer::getMemBuffer( From 42839d158df2b31b17aa028ea0040e7d714afc19 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 11 Nov 2025 15:41:01 +0100 Subject: [PATCH 16/77] load from text Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/compile.py | 20 +++++------ .../kernel/wave/execution_engine/bindings.cpp | 16 +++++++++ .../execution_engine/execution_engine.cpp | 34 ++++++++++++++++--- .../wave/execution_engine/execution_engine.h | 4 +++ 4 files changed, 58 insertions(+), 16 deletions(-) diff --git a/wave_lang/kernel/wave/compile.py b/wave_lang/kernel/wave/compile.py index 924ba4683..f659071f8 100644 --- a/wave_lang/kernel/wave/compile.py +++ b/wave_lang/kernel/wave/compile.py @@ -265,28 +265,26 @@ def __call__(self, *args, **kwargs): class WaveKernel2: - def __init__(self, options: WaveCompileOptions, module: Module | bytes): + def __init__(self, options: WaveCompileOptions, module: Module | bytes | str): self.options = options self._engine = None self._module_handle = None - # Serialize MLIR module to bytecode if needed - if isinstance(module, bytes): - self.bytecode = module + # Serialize MLIR module to text if needed + # TODO: investigate why bytecode deserialization is not working + if isinstance(module, (bytes, str)): + # Assume it's already MLIR text + mlir_asm = module.decode() if isinstance(module, bytes) else module else: - # Serialize the MLIR module to bytecode - import io - - bytecode_io = io.BytesIO() - module.operation.write_bytecode(bytecode_io) - self.bytecode = bytecode_io.getvalue() + # Serialize the MLIR module to text + mlir_asm = str(module) # Load module eagerly from wave_lang.kernel.wave.execution_engine import get_execution_engine self._engine = get_execution_engine() - self._module_handle = self._engine.load_module_from_bytecode(self.bytecode) + self._module_handle = self._engine.load_module_from_text(mlir_asm) def __call__(self, *args, **kwargs): return self.invoke(*args, **kwargs) diff --git a/wave_lang/kernel/wave/execution_engine/bindings.cpp b/wave_lang/kernel/wave/execution_engine/bindings.cpp index 6f2c4c33d..de9507299 100644 --- a/wave_lang/kernel/wave/execution_engine/bindings.cpp +++ b/wave_lang/kernel/wave/execution_engine/bindings.cpp @@ -89,6 +89,22 @@ NB_MODULE(wave_execution_engine, m) { " Module handle as integer\n\n" "Raises:\n" " RuntimeError: If deserialization, compilation or loading fails") + .def( + "load_module_from_text", + [](wave::ExecutionEngine &self, const std::string &mlirText) { + auto handle = unwrapExpected( + self.loadModuleFromText(llvm::StringRef(mlirText)), + "Failed to load module from text"); + return reinterpret_cast(handle); + }, + nb::arg("mlir_text"), + "Parse MLIR text and load it into the execution engine.\n\n" + "Args:\n" + " mlir_text: MLIR module as text string\n\n" + "Returns:\n" + " Module handle as integer\n\n" + "Raises:\n" + " RuntimeError: If parsing, compilation or loading fails") .def( "release_module", [](wave::ExecutionEngine &self, uintptr_t handle) { diff --git a/wave_lang/kernel/wave/execution_engine/execution_engine.cpp b/wave_lang/kernel/wave/execution_engine/execution_engine.cpp index 432d6b58d..d7640a24a 100644 --- a/wave_lang/kernel/wave/execution_engine/execution_engine.cpp +++ b/wave_lang/kernel/wave/execution_engine/execution_engine.cpp @@ -29,6 +29,7 @@ #include #include #include +#include #include #include @@ -379,14 +380,19 @@ wave::ExecutionEngine::loadModule(mlir::ModuleOp m) { return static_cast(dylib); } +static mlir::DialectRegistry createMLIRContextRegistry() { + mlir::DialectRegistry registry; + mlir::registerAllDialects(registry); + mlir::registerAllToLLVMIRTranslations(registry); + return registry; +} + llvm::Expected wave::ExecutionEngine::loadModuleFromBytecode(llvm::ArrayRef bytecode) { // Create MLIR context on demand if not already created - if (!mlirContext) { - mlir::DialectRegistry registry; - mlir::registerAllDialects(registry); - mlirContext = std::make_unique(registry); - } + if (!mlirContext) + mlirContext = + std::make_unique(createMLIRContextRegistry()); // Create memory buffer from bytecode auto memoryBuffer = llvm::MemoryBuffer::getMemBuffer( @@ -407,6 +413,24 @@ wave::ExecutionEngine::loadModuleFromBytecode(llvm::ArrayRef bytecode) { return loadModule(module.get()); } +llvm::Expected +wave::ExecutionEngine::loadModuleFromText(llvm::StringRef mlirText) { + // Create MLIR context on demand if not already created + if (!mlirContext) + mlirContext = + std::make_unique(createMLIRContextRegistry()); + + // Parse MLIR text + mlir::OwningOpRef module = + mlir::parseSourceString(mlirText, mlirContext.get()); + + if (!module) + return makeStringError("Failed to parse MLIR text"); + + // Load the parsed module + return loadModule(module.get()); +} + void wave::ExecutionEngine::releaseModule(ModuleHandle handle) { assert(handle); auto dylib = static_cast(handle); diff --git a/wave_lang/kernel/wave/execution_engine/execution_engine.h b/wave_lang/kernel/wave/execution_engine/execution_engine.h index a196d4db4..5e1454832 100644 --- a/wave_lang/kernel/wave/execution_engine/execution_engine.h +++ b/wave_lang/kernel/wave/execution_engine/execution_engine.h @@ -83,6 +83,10 @@ class ExecutionEngine { llvm::Expected loadModuleFromBytecode(llvm::ArrayRef bytecode); + /// Parses MLIR text from a string, compiles it, and loads it into the + /// execution engine. + llvm::Expected loadModuleFromText(llvm::StringRef mlirText); + /// Runs module desctructors and removes it from execution engine. void releaseModule(ModuleHandle handle); From 254cb7a0e7855244d2c8452581a817b9aec40570 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 15 Nov 2025 22:53:34 +0100 Subject: [PATCH 17/77] buffer utils Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/compile.py | 59 +++++++++++++--- .../wave/execution_engine/CMakeLists.txt | 1 + .../wave/execution_engine/buffer_utils.cpp | 69 +++++++++++++++++++ .../wave/execution_engine/buffer_utils.h | 39 +++++++++++ .../wave/execution_engine/execution_engine.py | 44 ++++++++++++ .../wave_execution_engine.lds | 4 +- 6 files changed, 206 insertions(+), 10 deletions(-) create mode 100644 wave_lang/kernel/wave/execution_engine/buffer_utils.cpp create mode 100644 wave_lang/kernel/wave/execution_engine/buffer_utils.h diff --git a/wave_lang/kernel/wave/compile.py b/wave_lang/kernel/wave/compile.py index f659071f8..5547c1a62 100644 --- a/wave_lang/kernel/wave/compile.py +++ b/wave_lang/kernel/wave/compile.py @@ -270,6 +270,7 @@ def __init__(self, options: WaveCompileOptions, module: Module | bytes | str): self._engine = None self._module_handle = None + self._host_func_ptr = None # Serialize MLIR module to text if needed # TODO: investigate why bytecode deserialization is not working @@ -286,6 +287,17 @@ def __init__(self, options: WaveCompileOptions, module: Module | bytes | str): self._engine = get_execution_engine() self._module_handle = self._engine.load_module_from_text(mlir_asm) + # Look up the host wrapper function + # The host wrapper is named "{kernel_name}_host_wrapper" by emit_host_func + func_name = f"{self.options.func_name}_host_wrapper" + try: + self._host_func_ptr = self._engine.lookup(self._module_handle, func_name) + except RuntimeError as e: + raise RuntimeError( + f"Failed to lookup function '{func_name}' in loaded module. " + f"Make sure the module was compiled with emit_host_func. Error: {e}" + ) + def __call__(self, *args, **kwargs): return self.invoke(*args, **kwargs) @@ -293,15 +305,44 @@ def invoke(self, *args, **kwargs): """ Invokes the wave kernel with the given arguments using the ExecutionEngine. """ - # TODO: Implement argument marshalling and function invocation - # This will need to: - # 1. Convert Python/PyTorch arguments to C-compatible pointers - # 2. Look up the kernel function in the loaded module - # 3. Call the function with ctypes - # 4. Handle return values - raise NotImplementedError( - "WaveKernel2.invoke: argument marshalling not yet implemented" - ) + import ctypes + + # The host wrapper signature is: + # void func(void* stream, void* arg0, void* arg1, ...) + # where stream is currently unused and args are PyObject* pointers to tensors + + # Create ctypes function type + # Return type is void, arguments are all void pointers + num_args = len(args) + arg_types = [ctypes.c_void_p] * (num_args + 1) # +1 for stream pointer + func_type = ctypes.CFUNCTYPE(None, *arg_types) + + # Cast the function pointer + cfunc = func_type(self._host_func_ptr) + + # Prepare arguments + # Stream pointer (currently unused, pass NULL) + stream_ptr = None + + # Convert PyTorch tensors to PyObject* using id() + # id() returns the memory address of the Python object + py_args = [] + for arg in args: + if isinstance(arg, torch.Tensor): + # Get pointer to PyObject* for this tensor + obj_ptr = id(arg) + py_args.append(obj_ptr) + else: + # For scalars, also pass as PyObject* + obj_ptr = id(arg) + py_args.append(obj_ptr) + + # Call the function + # Note: stream_ptr is None which ctypes will convert to NULL + cfunc(stream_ptr, *py_args) + + # Return None (kernel modifies output tensors in place) + return None def __del__(self): """Clean up the loaded module when the kernel is destroyed.""" diff --git a/wave_lang/kernel/wave/execution_engine/CMakeLists.txt b/wave_lang/kernel/wave/execution_engine/CMakeLists.txt index 78ff8ce7a..e2bfb391c 100644 --- a/wave_lang/kernel/wave/execution_engine/CMakeLists.txt +++ b/wave_lang/kernel/wave/execution_engine/CMakeLists.txt @@ -53,6 +53,7 @@ add_definitions(${LLVM_DEFINITIONS_LIST}) add_library(wave_execution_engine MODULE execution_engine.cpp bindings.cpp + buffer_utils.cpp ) # Disable RTTI for execution_engine.cpp to avoid typeinfo symbol dependencies diff --git a/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp b/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp new file mode 100644 index 000000000..f2beb967d --- /dev/null +++ b/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp @@ -0,0 +1,69 @@ +// Copyright 2025 The IREE Authors +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "buffer_utils.h" +#include +#include +#include + +// PyTorch C API definitions +// We use weak symbols so the code compiles even if PyTorch is not available +// The symbols will be resolved at runtime when PyTorch is loaded + +extern "C" { +// PyTorch Tensor C API functions (from torch/csrc/Module.h) +void *__attribute__((weak)) THPVariable_Unpack(PyObject *obj); +void *__attribute__((weak)) at_tensor_data_ptr(void *tensor); +int64_t __attribute__((weak)) at_tensor_numel(void *tensor); +int64_t __attribute__((weak)) at_tensor_element_size(void *tensor); +} + +/// Helper to check if PyTorch symbols are available +static bool isPyTorchAvailable() { + return THPVariable_Unpack != nullptr && at_tensor_data_ptr != nullptr && + at_tensor_numel != nullptr && at_tensor_element_size != nullptr; +} + +extern "C" MemRef1Di8 wave_get_buffer(PyObject *obj) { + if (!obj) { + throw std::runtime_error("wave_get_buffer: NULL PyObject"); + } + + // Check if PyTorch is available + if (!isPyTorchAvailable()) { + throw std::runtime_error( + "wave_get_buffer: PyTorch C API symbols not found. " + "Make sure PyTorch is loaded before calling this function."); + } + + // Extract the ATen tensor from the PyTorch Python object + void *tensor = THPVariable_Unpack(obj); + if (!tensor) { + throw std::runtime_error( + "wave_get_buffer: Failed to unpack PyTorch tensor. " + "Object is not a valid torch.Tensor."); + } + + // Get the data pointer + void *data_ptr = at_tensor_data_ptr(tensor); + if (!data_ptr) { + throw std::runtime_error("wave_get_buffer: Tensor has NULL data pointer"); + } + + // Calculate total size in bytes + int64_t numel = at_tensor_numel(tensor); + int64_t element_size = at_tensor_element_size(tensor); + int64_t total_bytes = numel * element_size; + + // Create and return memref descriptor + MemRef1Di8 descriptor; + descriptor.basePtr = static_cast(data_ptr); + descriptor.data = static_cast(data_ptr); + descriptor.offset = 0; + descriptor.sizes[0] = total_bytes; + descriptor.strides[0] = 1; + + return descriptor; +} diff --git a/wave_lang/kernel/wave/execution_engine/buffer_utils.h b/wave_lang/kernel/wave/execution_engine/buffer_utils.h new file mode 100644 index 000000000..b2f2801a4 --- /dev/null +++ b/wave_lang/kernel/wave/execution_engine/buffer_utils.h @@ -0,0 +1,39 @@ +// Copyright 2025 The IREE Authors +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#pragma once + +#include +#include + +/// StridedMemRefType is the descriptor structure used by MLIR for memrefs. +/// This matches the ABI used by MLIR's memref lowering. +template struct StridedMemRefType { + T *basePtr; // Pointer to the allocated buffer + T *data; // Aligned data pointer + int64_t offset; // Offset in elements + int64_t sizes[N]; // Size of each dimension + int64_t strides[N]; // Stride of each dimension in elements +}; + +/// Rank-1 memref descriptor for memref +using MemRef1Di8 = StridedMemRefType; + +extern "C" { + +/// Extract a raw buffer pointer from a PyObject (PyTorch tensor). +/// Returns a rank-1 memref descriptor: memref +/// +/// The returned descriptor has: +/// - basePtr: pointer to the raw data +/// - data: same as basePtr (no alignment offset) +/// - offset: 0 +/// - sizes[0]: total size in bytes +/// - strides[0]: 1 +/// +/// This function assumes the PyObject is a PyTorch tensor and uses +/// the PyTorch C API to extract the data pointer and size. +MemRef1Di8 wave_get_buffer(PyObject *obj); +} diff --git a/wave_lang/kernel/wave/execution_engine/execution_engine.py b/wave_lang/kernel/wave/execution_engine/execution_engine.py index 0a20f7896..7d627b5ec 100644 --- a/wave_lang/kernel/wave/execution_engine/execution_engine.py +++ b/wave_lang/kernel/wave/execution_engine/execution_engine.py @@ -27,6 +27,50 @@ _cached_engine: Optional[weakref.ref] = None +def _get_wave_get_buffer_address(): + """ + Get the address of the wave_get_buffer function. + + This function is defined in buffer_utils.cpp and needs to be accessible + to JIT-compiled code. + + Returns: + Integer address of wave_get_buffer, or None if not available + """ + import ctypes + import sys + + # Try to find wave_get_buffer in the current process + # It should be loaded as part of the wave_execution_engine extension + try: + # Get handle to current process + if sys.platform == "linux": + RTLD_DEFAULT = ctypes.cast(0, ctypes.c_void_p) + libc = ctypes.CDLL(None) + dlsym = libc.dlsym + dlsym.argtypes = [ctypes.c_void_p, ctypes.c_char_p] + dlsym.restype = ctypes.c_void_p + + addr = dlsym(RTLD_DEFAULT, b"wave_get_buffer") + if addr: + return addr + elif sys.platform == "darwin": + # macOS + RTLD_DEFAULT = ctypes.cast(-2, ctypes.c_void_p) + libc = ctypes.CDLL(None) + dlsym = libc.dlsym + dlsym.argtypes = [ctypes.c_void_p, ctypes.c_char_p] + dlsym.restype = ctypes.c_void_p + + addr = dlsym(RTLD_DEFAULT, b"wave_get_buffer") + if addr: + return addr + except Exception: + pass + + return None + + def _create_options_from_env() -> "ExecutionEngineOptions": """ Create ExecutionEngineOptions from environment variables. diff --git a/wave_lang/kernel/wave/execution_engine/wave_execution_engine.lds b/wave_lang/kernel/wave/execution_engine/wave_execution_engine.lds index 5c9d44621..1a6f0f222 100644 --- a/wave_lang/kernel/wave/execution_engine/wave_execution_engine.lds +++ b/wave_lang/kernel/wave/execution_engine/wave_execution_engine.lds @@ -1,9 +1,11 @@ # Linker version script for wave_execution_engine -# Hide all symbols except the Python module entry point +# Hide all symbols except the Python module entry point and runtime helpers { global: # Python module initialization function PyInit_wave_execution_engine; + # Runtime helper functions for JIT-compiled code + wave_get_buffer; local: # Hide all other symbols *; From f21d59704ee1c3ced8cd3a6b992da8979ea1d79b Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 15 Nov 2025 23:08:15 +0100 Subject: [PATCH 18/77] host wrapper name Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/compile.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/wave_lang/kernel/wave/compile.py b/wave_lang/kernel/wave/compile.py index 5547c1a62..0d477f8b7 100644 --- a/wave_lang/kernel/wave/compile.py +++ b/wave_lang/kernel/wave/compile.py @@ -288,8 +288,7 @@ def __init__(self, options: WaveCompileOptions, module: Module | bytes | str): self._module_handle = self._engine.load_module_from_text(mlir_asm) # Look up the host wrapper function - # The host wrapper is named "{kernel_name}_host_wrapper" by emit_host_func - func_name = f"{self.options.func_name}_host_wrapper" + func_name = self.options.func_name try: self._host_func_ptr = self._engine.lookup(self._module_handle, func_name) except RuntimeError as e: From 413aa851402404094c00a8b6bed60e7245fff4b4 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 15 Nov 2025 23:20:58 +0100 Subject: [PATCH 19/77] move ctypes Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/compile.py | 57 ++++++++++++-------------------- 1 file changed, 21 insertions(+), 36 deletions(-) diff --git a/wave_lang/kernel/wave/compile.py b/wave_lang/kernel/wave/compile.py index 0d477f8b7..2bcf1ed91 100644 --- a/wave_lang/kernel/wave/compile.py +++ b/wave_lang/kernel/wave/compile.py @@ -297,6 +297,15 @@ def __init__(self, options: WaveCompileOptions, module: Module | bytes | str): f"Make sure the module was compiled with emit_host_func. Error: {e}" ) + # Create ctypes function type once + # The host wrapper signature is: void func(void* stream, void* arg0, void* arg1, ...) + import ctypes + + num_kernel_args = len(self.options.kernel_usages) + arg_types = [ctypes.c_void_p] * (num_kernel_args + 1) # +1 for stream pointer + func_type = ctypes.CFUNCTYPE(None, *arg_types) + self._cfunc = func_type(self._host_func_ptr) + def __call__(self, *args, **kwargs): return self.invoke(*args, **kwargs) @@ -304,41 +313,17 @@ def invoke(self, *args, **kwargs): """ Invokes the wave kernel with the given arguments using the ExecutionEngine. """ - import ctypes - - # The host wrapper signature is: - # void func(void* stream, void* arg0, void* arg1, ...) - # where stream is currently unused and args are PyObject* pointers to tensors - - # Create ctypes function type - # Return type is void, arguments are all void pointers - num_args = len(args) - arg_types = [ctypes.c_void_p] * (num_args + 1) # +1 for stream pointer - func_type = ctypes.CFUNCTYPE(None, *arg_types) - - # Cast the function pointer - cfunc = func_type(self._host_func_ptr) - - # Prepare arguments + # Prepare arguments for the host wrapper # Stream pointer (currently unused, pass NULL) stream_ptr = None - # Convert PyTorch tensors to PyObject* using id() + # Convert arguments to PyObject* pointers using id() # id() returns the memory address of the Python object - py_args = [] - for arg in args: - if isinstance(arg, torch.Tensor): - # Get pointer to PyObject* for this tensor - obj_ptr = id(arg) - py_args.append(obj_ptr) - else: - # For scalars, also pass as PyObject* - obj_ptr = id(arg) - py_args.append(obj_ptr) + py_args = [id(arg) for arg in args] - # Call the function - # Note: stream_ptr is None which ctypes will convert to NULL - cfunc(stream_ptr, *py_args) + # Call the JIT-compiled host wrapper function + # Signature: void func(void* stream, void* arg0, void* arg1, ...) + self._cfunc(stream_ptr, *py_args) # Return None (kernel modifies output tensors in place) return None @@ -526,6 +511,12 @@ def get_binary_path(): # Handle ASM and LLVM backends in a clear, single-pass flow compiled_wave_vmfb = None + kernel_usages = [ + binding.kernel_buffer_type.usage + for binding in kernel_sig.kernel_buffer_bindings + ] + options.kernel_usages = kernel_usages + if options.compile_to_asm or options.backend == "asm": # ASM flow: generate AMDGCN assembly; optionally build a binary asm = _generate_asm_code(mb, options) @@ -558,12 +549,6 @@ def get_binary_path(): if options.create_vmfb_file: write_file(options.create_vmfb_file, "wb", compiled_wave_vmfb) - kernel_usages = [ - binding.kernel_buffer_type.usage - for binding in kernel_sig.kernel_buffer_bindings - ] - options.kernel_usages = kernel_usages - if is_cache_enabled() and not debug_arg_info: cache_manager.store_kernel( compiled_wave_vmfb, From 6173205a9d8e0c848bc4d9479bb11a282e108065 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 16 Nov 2025 00:18:08 +0100 Subject: [PATCH 20/77] use current stream Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/compile.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/wave_lang/kernel/wave/compile.py b/wave_lang/kernel/wave/compile.py index 2bcf1ed91..569811282 100644 --- a/wave_lang/kernel/wave/compile.py +++ b/wave_lang/kernel/wave/compile.py @@ -313,9 +313,8 @@ def invoke(self, *args, **kwargs): """ Invokes the wave kernel with the given arguments using the ExecutionEngine. """ - # Prepare arguments for the host wrapper - # Stream pointer (currently unused, pass NULL) - stream_ptr = None + # Get the current CUDA stream + stream_ptr = torch.cuda.current_stream().cuda_stream # Convert arguments to PyObject* pointers using id() # id() returns the memory address of the Python object From dd234223a459652af5399ee6729ee1854bb06e5b Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 16 Nov 2025 01:23:03 +0100 Subject: [PATCH 21/77] runtime Signed-off-by: Ivan Butygin --- .../wave/execution_engine/CMakeLists.txt | 19 ++++++- .../wave/execution_engine/execution_engine.py | 56 +++++++++++++++++++ .../wave_execution_engine.lds | 2 - 3 files changed, 73 insertions(+), 4 deletions(-) diff --git a/wave_lang/kernel/wave/execution_engine/CMakeLists.txt b/wave_lang/kernel/wave/execution_engine/CMakeLists.txt index e2bfb391c..20f3d2ddf 100644 --- a/wave_lang/kernel/wave/execution_engine/CMakeLists.txt +++ b/wave_lang/kernel/wave/execution_engine/CMakeLists.txt @@ -49,11 +49,26 @@ include_directories(${MLIR_INCLUDE_DIRS}) separate_arguments(LLVM_DEFINITIONS_LIST NATIVE_COMMAND ${LLVM_DEFINITIONS}) add_definitions(${LLVM_DEFINITIONS_LIST}) +# Compile buffer utils as a separate shared library +add_library(wave_runtime_helpers SHARED + buffer_utils.cpp +) + +# Set visibility for wave_runtime_helpers +set_target_properties(wave_runtime_helpers PROPERTIES + CXX_VISIBILITY_PRESET default + VISIBILITY_INLINES_HIDDEN OFF +) + +# Link Python for wave_runtime_helpers +target_link_libraries(wave_runtime_helpers PRIVATE + Python::Python +) + # Compile an extension library add_library(wave_execution_engine MODULE execution_engine.cpp bindings.cpp - buffer_utils.cpp ) # Disable RTTI for execution_engine.cpp to avoid typeinfo symbol dependencies @@ -127,4 +142,4 @@ nanobind_extension(wave_execution_engine) # Set important linker flags nanobind_link_options(wave_execution_engine) -install(TARGETS wave_execution_engine DESTINATION ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) +install(TARGETS wave_execution_engine wave_runtime_helpers DESTINATION ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) diff --git a/wave_lang/kernel/wave/execution_engine/execution_engine.py b/wave_lang/kernel/wave/execution_engine/execution_engine.py index 7d627b5ec..0e116e582 100644 --- a/wave_lang/kernel/wave/execution_engine/execution_engine.py +++ b/wave_lang/kernel/wave/execution_engine/execution_engine.py @@ -27,6 +27,59 @@ _cached_engine: Optional[weakref.ref] = None +def _load_runtime_helpers(): + """ + Load the wave_runtime_helpers shared library. + + This library contains runtime helper functions like wave_get_buffer that are + needed by JIT-compiled code. We load it once globally so that the symbols + are available when the ExecutionEngine uses DynamicLibrarySearchGenerator. + + Raises: + RuntimeError: If the library cannot be found or loaded + """ + import ctypes + import platform + from pathlib import Path + + # Find the library file + lib_name = { + "Linux": "libwave_runtime_helpers.so", + "Darwin": "libwave_runtime_helpers.dylib", + "Windows": "wave_runtime_helpers.dll", + }.get(platform.system()) + + if not lib_name: + raise RuntimeError( + f"Unsupported platform: {platform.system()}. " + "wave_runtime_helpers is only available on Linux, macOS, and Windows." + ) + + # Look for the library in the same directory as this module + module_dir = Path(__file__).parent + lib_path = module_dir / lib_name + + if not lib_path.exists(): + raise RuntimeError( + f"wave_runtime_helpers library not found at {lib_path}. " + "Please build the C++ extension: " + "cd wave_lang/kernel/wave/execution_engine && cmake -B build && cmake --build build" + ) + + # Load the library globally with RTLD_GLOBAL so symbols are visible + try: + if platform.system() == "Windows": + ctypes.CDLL(str(lib_path), mode=ctypes.RTLD_GLOBAL) + else: + # On Unix, use RTLD_GLOBAL to make symbols visible to dlsym + ctypes.CDLL(str(lib_path), mode=ctypes.RTLD_GLOBAL) + except OSError as e: + raise RuntimeError( + f"Failed to load wave_runtime_helpers from {lib_path}: {e}. " + "The library may be missing dependencies or corrupted." + ) from e + + def _get_wave_get_buffer_address(): """ Get the address of the wave_get_buffer function. @@ -129,6 +182,9 @@ def get_execution_engine() -> "ExecutionEngine": if engine is not None: return engine + # Load wave_runtime_helpers library to make wave_get_buffer available + _load_runtime_helpers() + # Create new instance with options from environment options = _create_options_from_env() engine = ExecutionEngine(options) diff --git a/wave_lang/kernel/wave/execution_engine/wave_execution_engine.lds b/wave_lang/kernel/wave/execution_engine/wave_execution_engine.lds index 1a6f0f222..2bb87b027 100644 --- a/wave_lang/kernel/wave/execution_engine/wave_execution_engine.lds +++ b/wave_lang/kernel/wave/execution_engine/wave_execution_engine.lds @@ -4,8 +4,6 @@ global: # Python module initialization function PyInit_wave_execution_engine; - # Runtime helper functions for JIT-compiled code - wave_get_buffer; local: # Hide all other symbols *; From 0da8d146653fb4f58f9c8a5a0f5f77df537c7a4d Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 16 Nov 2025 01:31:26 +0100 Subject: [PATCH 22/77] fix import Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/execution_engine/execution_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wave_lang/kernel/wave/execution_engine/execution_engine.py b/wave_lang/kernel/wave/execution_engine/execution_engine.py index 0e116e582..e4bc566eb 100644 --- a/wave_lang/kernel/wave/execution_engine/execution_engine.py +++ b/wave_lang/kernel/wave/execution_engine/execution_engine.py @@ -16,7 +16,7 @@ from typing import Optional try: - from wave_execution_engine import ExecutionEngine, ExecutionEngineOptions + from .wave_execution_engine import ExecutionEngine, ExecutionEngineOptions except ImportError: # Allow import to succeed even if C++ module not built yet ExecutionEngine = None From 3cc0aedce2546bb58ca9481126b4d6c9fdd8921c Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 16 Nov 2025 12:26:47 +0100 Subject: [PATCH 23/77] expose symbolMap Signed-off-by: Ivan Butygin --- .../kernel/wave/execution_engine/bindings.cpp | 29 ++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/wave_lang/kernel/wave/execution_engine/bindings.cpp b/wave_lang/kernel/wave/execution_engine/bindings.cpp index de9507299..46cdbf3cd 100644 --- a/wave_lang/kernel/wave/execution_engine/bindings.cpp +++ b/wave_lang/kernel/wave/execution_engine/bindings.cpp @@ -5,9 +5,12 @@ #include "execution_engine.h" #include +#include #include #include +#include +#include #include #include @@ -47,7 +50,31 @@ NB_MODULE(wave_execution_engine, m) { "Enable GDB notification listener") .def_rw("enable_perf_notification_listener", &wave::ExecutionEngineOptions::enablePerfNotificationListener, - "Enable Perf notification listener"); + "Enable Perf notification listener") + .def( + "set_symbol_map", + [](wave::ExecutionEngineOptions &self, + const std::map &symbols) { + // Convert Python dict to C++ symbolMap function + self.symbolMap = [symbols](llvm::orc::MangleAndInterner mangle) { + llvm::orc::SymbolMap symbolMap; + for (const auto &[name, address] : symbols) { + auto mangledName = mangle(name); + auto flags = llvm::JITSymbolFlags::Exported | + llvm::JITSymbolFlags::Callable; + symbolMap[mangledName] = llvm::orc::ExecutorSymbolDef( + llvm::orc::ExecutorAddr(address), flags); + } + return symbolMap; + }; + }, + nb::arg("symbols"), + "Set symbol map from a dictionary of symbol names to addresses.\n\n" + "Args:\n" + " symbols: Dictionary mapping symbol names (str) to addresses " + "(int)\n\n" + "Example:\n" + " options.set_symbol_map({'my_function': 0x12345678})"); // Bind ExecutionEngine class nb::class_(m, "ExecutionEngine", From 046b578c865feb6aafe0e0582759e470a4c4300b Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 16 Nov 2025 12:43:35 +0100 Subject: [PATCH 24/77] cleanup Signed-off-by: Ivan Butygin --- .../wave/execution_engine/execution_engine.py | 106 ++++++------------ 1 file changed, 37 insertions(+), 69 deletions(-) diff --git a/wave_lang/kernel/wave/execution_engine/execution_engine.py b/wave_lang/kernel/wave/execution_engine/execution_engine.py index e4bc566eb..b0ec42f25 100644 --- a/wave_lang/kernel/wave/execution_engine/execution_engine.py +++ b/wave_lang/kernel/wave/execution_engine/execution_engine.py @@ -27,13 +27,15 @@ _cached_engine: Optional[weakref.ref] = None -def _load_runtime_helpers(): +def _load_library(lib_basename: str): """ - Load the wave_runtime_helpers shared library. + Load a shared library from the execution_engine directory. - This library contains runtime helper functions like wave_get_buffer that are - needed by JIT-compiled code. We load it once globally so that the symbols - are available when the ExecutionEngine uses DynamicLibrarySearchGenerator. + Args: + lib_basename: Base name of library (e.g., "wave_runtime_helpers") + + Returns: + ctypes.CDLL library handle Raises: RuntimeError: If the library cannot be found or loaded @@ -44,84 +46,47 @@ def _load_runtime_helpers(): # Find the library file lib_name = { - "Linux": "libwave_runtime_helpers.so", - "Darwin": "libwave_runtime_helpers.dylib", - "Windows": "wave_runtime_helpers.dll", + "Linux": f"lib{lib_basename}.so", + "Darwin": f"lib{lib_basename}.dylib", + "Windows": f"{lib_basename}.dll", }.get(platform.system()) if not lib_name: - raise RuntimeError( - f"Unsupported platform: {platform.system()}. " - "wave_runtime_helpers is only available on Linux, macOS, and Windows." - ) + raise RuntimeError(f"Unsupported platform: {platform.system()}.") # Look for the library in the same directory as this module module_dir = Path(__file__).parent lib_path = module_dir / lib_name if not lib_path.exists(): - raise RuntimeError( - f"wave_runtime_helpers library not found at {lib_path}. " - "Please build the C++ extension: " - "cd wave_lang/kernel/wave/execution_engine && cmake -B build && cmake --build build" - ) + raise RuntimeError(f"{lib_basename} library not found at {lib_path}. ") - # Load the library globally with RTLD_GLOBAL so symbols are visible - try: - if platform.system() == "Windows": - ctypes.CDLL(str(lib_path), mode=ctypes.RTLD_GLOBAL) - else: - # On Unix, use RTLD_GLOBAL to make symbols visible to dlsym - ctypes.CDLL(str(lib_path), mode=ctypes.RTLD_GLOBAL) - except OSError as e: - raise RuntimeError( - f"Failed to load wave_runtime_helpers from {lib_path}: {e}. " - "The library may be missing dependencies or corrupted." - ) from e + return ctypes.CDLL(str(lib_path), mode=ctypes.RTLD_GLOBAL) -def _get_wave_get_buffer_address(): +def _load_runtime_helpers(): """ - Get the address of the wave_get_buffer function. + Load the wave_runtime_helpers shared library and return symbol addresses. - This function is defined in buffer_utils.cpp and needs to be accessible - to JIT-compiled code. + This library contains runtime helper functions like wave_get_buffer that are + needed by JIT-compiled code. Returns: - Integer address of wave_get_buffer, or None if not available + Dictionary mapping symbol names to their addresses + + Raises: + RuntimeError: If the library cannot be found or loaded """ import ctypes - import sys - - # Try to find wave_get_buffer in the current process - # It should be loaded as part of the wave_execution_engine extension - try: - # Get handle to current process - if sys.platform == "linux": - RTLD_DEFAULT = ctypes.cast(0, ctypes.c_void_p) - libc = ctypes.CDLL(None) - dlsym = libc.dlsym - dlsym.argtypes = [ctypes.c_void_p, ctypes.c_char_p] - dlsym.restype = ctypes.c_void_p - - addr = dlsym(RTLD_DEFAULT, b"wave_get_buffer") - if addr: - return addr - elif sys.platform == "darwin": - # macOS - RTLD_DEFAULT = ctypes.cast(-2, ctypes.c_void_p) - libc = ctypes.CDLL(None) - dlsym = libc.dlsym - dlsym.argtypes = [ctypes.c_void_p, ctypes.c_char_p] - dlsym.restype = ctypes.c_void_p - - addr = dlsym(RTLD_DEFAULT, b"wave_get_buffer") - if addr: - return addr - except Exception: - pass - - return None + + lib = _load_library("wave_runtime_helpers") + + symbol_map = {} + + wave_get_buffer_addr = ctypes.cast(lib.wave_get_buffer, ctypes.c_void_p).value + symbol_map["wave_get_buffer"] = wave_get_buffer_addr + + return symbol_map def _create_options_from_env() -> "ExecutionEngineOptions": @@ -182,14 +147,17 @@ def get_execution_engine() -> "ExecutionEngine": if engine is not None: return engine - # Load wave_runtime_helpers library to make wave_get_buffer available - _load_runtime_helpers() + symbol_map = {} + + symbol_map.update(_load_runtime_helpers()) - # Create new instance with options from environment options = _create_options_from_env() + + if symbol_map: + options.set_symbol_map(symbol_map) + engine = ExecutionEngine(options) - # Cache using weak reference _cached_engine = weakref.ref(engine) return engine From be6f25ea27d96d92b2772f625eab5325604f2f94 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 16 Nov 2025 12:51:43 +0100 Subject: [PATCH 25/77] hip_runtime Signed-off-by: Ivan Butygin --- .../wave/execution_engine/CMakeLists.txt | 13 ++++- .../wave/execution_engine/execution_engine.py | 35 +++++++++++++ .../execution_engine/wave_hip_runtime.cpp | 52 +++++++++++++++++++ .../wave/execution_engine/wave_hip_runtime.h | 49 +++++++++++++++++ 4 files changed, 148 insertions(+), 1 deletion(-) create mode 100644 wave_lang/kernel/wave/execution_engine/wave_hip_runtime.cpp create mode 100644 wave_lang/kernel/wave/execution_engine/wave_hip_runtime.h diff --git a/wave_lang/kernel/wave/execution_engine/CMakeLists.txt b/wave_lang/kernel/wave/execution_engine/CMakeLists.txt index 20f3d2ddf..dd7a9bd0c 100644 --- a/wave_lang/kernel/wave/execution_engine/CMakeLists.txt +++ b/wave_lang/kernel/wave/execution_engine/CMakeLists.txt @@ -65,6 +65,17 @@ target_link_libraries(wave_runtime_helpers PRIVATE Python::Python ) +# Compile HIP runtime as a separate shared library +add_library(wave_hip_runtime SHARED + wave_hip_runtime.cpp +) + +# Set visibility for wave_hip_runtime +set_target_properties(wave_hip_runtime PROPERTIES + CXX_VISIBILITY_PRESET default + VISIBILITY_INLINES_HIDDEN OFF +) + # Compile an extension library add_library(wave_execution_engine MODULE execution_engine.cpp @@ -142,4 +153,4 @@ nanobind_extension(wave_execution_engine) # Set important linker flags nanobind_link_options(wave_execution_engine) -install(TARGETS wave_execution_engine wave_runtime_helpers DESTINATION ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) +install(TARGETS wave_execution_engine wave_runtime_helpers wave_hip_runtime DESTINATION ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) diff --git a/wave_lang/kernel/wave/execution_engine/execution_engine.py b/wave_lang/kernel/wave/execution_engine/execution_engine.py index b0ec42f25..85a63ac2d 100644 --- a/wave_lang/kernel/wave/execution_engine/execution_engine.py +++ b/wave_lang/kernel/wave/execution_engine/execution_engine.py @@ -89,6 +89,37 @@ def _load_runtime_helpers(): return symbol_map +def _load_hip_runtime(): + """ + Load the wave_hip_runtime shared library and return symbol addresses. + + This library contains HIP runtime functions for kernel loading and launching. + + Returns: + Dictionary mapping symbol names to their addresses + + Raises: + RuntimeError: If the library cannot be found or loaded + """ + import ctypes + + lib = _load_library("wave_hip_runtime") + + symbol_map = {} + + # Register HIP runtime functions + wave_load_kernel_addr = ctypes.cast(lib.wave_load_kernel, ctypes.c_void_p).value + symbol_map["wave_load_kernel"] = wave_load_kernel_addr + + wave_launch_kernel_addr = ctypes.cast(lib.wave_launch_kernel, ctypes.c_void_p).value + symbol_map["wave_launch_kernel"] = wave_launch_kernel_addr + + wave_unload_kernel_addr = ctypes.cast(lib.wave_unload_kernel, ctypes.c_void_p).value + symbol_map["wave_unload_kernel"] = wave_unload_kernel_addr + + return symbol_map + + def _create_options_from_env() -> "ExecutionEngineOptions": """ Create ExecutionEngineOptions from environment variables. @@ -149,8 +180,12 @@ def get_execution_engine() -> "ExecutionEngine": symbol_map = {} + # Load runtime helpers (wave_get_buffer, etc.) symbol_map.update(_load_runtime_helpers()) + # Load HIP runtime (wave_load_kernel, wave_launch_kernel, etc.) + symbol_map.update(_load_hip_runtime()) + options = _create_options_from_env() if symbol_map: diff --git a/wave_lang/kernel/wave/execution_engine/wave_hip_runtime.cpp b/wave_lang/kernel/wave/execution_engine/wave_hip_runtime.cpp new file mode 100644 index 000000000..a32cee3b9 --- /dev/null +++ b/wave_lang/kernel/wave/execution_engine/wave_hip_runtime.cpp @@ -0,0 +1,52 @@ +// Copyright 2025 The IREE Authors +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "wave_hip_runtime.h" +#include +#include +#include + +// TODO: Include HIP headers when implementing +// #include + +extern "C" WaveKernelHandle wave_load_kernel(const char *binary_path, + const char *kernel_name) { + // TODO: Implement kernel loading + // 1. Load binary file + // 2. Use hipModuleLoadData or similar + // 3. Get kernel function handle + // 4. Return opaque handle + + fprintf(stderr, + "wave_load_kernel: stub implementation (binary=%s, kernel=%s)\n", + binary_path, kernel_name); + return nullptr; +} + +extern "C" int wave_launch_kernel(WaveKernelHandle handle, void *stream, + uint32_t grid_x, uint32_t grid_y, + uint32_t grid_z, uint32_t block_x, + uint32_t block_y, uint32_t block_z, + void **args, size_t num_args) { + // TODO: Implement kernel launch + // 1. Validate handle + // 2. Set up launch parameters + // 3. Use hipModuleLaunchKernel or similar + // 4. Return status + + fprintf(stderr, + "wave_launch_kernel: stub implementation (grid=[%u,%u,%u], " + "block=[%u,%u,%u], args=%zu)\n", + grid_x, grid_y, grid_z, block_x, block_y, block_z, num_args); + return -1; // Not implemented +} + +extern "C" void wave_unload_kernel(WaveKernelHandle handle) { + // TODO: Implement kernel unloading + // 1. Free module resources + // 2. Clean up handle + + fprintf(stderr, "wave_unload_kernel: stub implementation\n"); +} diff --git a/wave_lang/kernel/wave/execution_engine/wave_hip_runtime.h b/wave_lang/kernel/wave/execution_engine/wave_hip_runtime.h new file mode 100644 index 000000000..3ebcb896e --- /dev/null +++ b/wave_lang/kernel/wave/execution_engine/wave_hip_runtime.h @@ -0,0 +1,49 @@ +// Copyright 2025 The IREE Authors +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#pragma once + +#include +#include + +extern "C" { + +/// Opaque kernel handle type +typedef void *WaveKernelHandle; + +/// Load a GPU kernel from a binary file +/// +/// Args: +/// binary_path: Path to the kernel binary (.hsaco file) +/// kernel_name: Name of the kernel function to load +/// +/// Returns: +/// Opaque kernel handle, or nullptr on failure +WaveKernelHandle wave_load_kernel(const char *binary_path, + const char *kernel_name); + +/// Launch a GPU kernel +/// +/// Args: +/// handle: Kernel handle from wave_load_kernel +/// stream: HIP stream pointer +/// grid_x, grid_y, grid_z: Grid dimensions +/// block_x, block_y, block_z: Block dimensions +/// args: Pointer to array of kernel argument pointers +/// num_args: Number of kernel arguments +/// +/// Returns: +/// 0 on success, non-zero on failure +int wave_launch_kernel(WaveKernelHandle handle, void *stream, uint32_t grid_x, + uint32_t grid_y, uint32_t grid_z, uint32_t block_x, + uint32_t block_y, uint32_t block_z, void **args, + size_t num_args); + +/// Unload a GPU kernel +/// +/// Args: +/// handle: Kernel handle from wave_load_kernel +void wave_unload_kernel(WaveKernelHandle handle); +} From f7f458c143f0d9e0251bb2a45d5f3abc39e1e2b9 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 16 Nov 2025 13:03:05 +0100 Subject: [PATCH 26/77] torch utils Signed-off-by: Ivan Butygin --- .../wave/execution_engine/buffer_utils.cpp | 73 +++++++++++-------- 1 file changed, 44 insertions(+), 29 deletions(-) diff --git a/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp b/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp index f2beb967d..da4de743b 100644 --- a/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp +++ b/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp @@ -6,40 +6,57 @@ #include "buffer_utils.h" #include #include +#include #include +#include -// PyTorch C API definitions -// We use weak symbols so the code compiles even if PyTorch is not available -// The symbols will be resolved at runtime when PyTorch is loaded +// PyTorch Tensor C API function pointers +typedef void *(*THPVariable_Unpack_t)(PyObject *); +typedef void *(*at_tensor_data_ptr_t)(void *); +typedef int64_t (*at_tensor_numel_t)(void *); +typedef int64_t (*at_tensor_element_size_t)(void *); -extern "C" { -// PyTorch Tensor C API functions (from torch/csrc/Module.h) -void *__attribute__((weak)) THPVariable_Unpack(PyObject *obj); -void *__attribute__((weak)) at_tensor_data_ptr(void *tensor); -int64_t __attribute__((weak)) at_tensor_numel(void *tensor); -int64_t __attribute__((weak)) at_tensor_element_size(void *tensor); -} +// Global function pointers (initialized on first use) +static THPVariable_Unpack_t THPVariable_Unpack_ptr = nullptr; +static at_tensor_data_ptr_t at_tensor_data_ptr_ptr = nullptr; +static at_tensor_numel_t at_tensor_numel_ptr = nullptr; -/// Helper to check if PyTorch symbols are available -static bool isPyTorchAvailable() { - return THPVariable_Unpack != nullptr && at_tensor_data_ptr != nullptr && - at_tensor_numel != nullptr && at_tensor_element_size != nullptr; +// Helper to get symbol address +static void *get_symbol_address(void *handle, const char *symbol_name) { + return dlsym(handle, symbol_name); } -extern "C" MemRef1Di8 wave_get_buffer(PyObject *obj) { - if (!obj) { - throw std::runtime_error("wave_get_buffer: NULL PyObject"); - } +// Macro to load a function pointer and check for errors +#define GET_FUNC(handle, name) \ + do { \ + name = \ + reinterpret_cast(get_symbol_address(handle, #name)); \ + if (!name) { \ + throw std::runtime_error("Failed to load PyTorch symbol: " + \ + std::string(#name)); \ + } \ + } while (0) - // Check if PyTorch is available - if (!isPyTorchAvailable()) { - throw std::runtime_error( - "wave_get_buffer: PyTorch C API symbols not found. " - "Make sure PyTorch is loaded before calling this function."); +/// Initialize PyTorch C API function pointers using dlsym +static void initPyTorchSymbols() { + if (THPVariable_Unpack_ptr != nullptr) { + return; // Already initialized } + // Use RTLD_DEFAULT to search all loaded libraries + void *handle = RTLD_DEFAULT; + + GET_FUNC(handle, THPVariable_Unpack_ptr); + GET_FUNC(handle, at_tensor_data_ptr_ptr); + GET_FUNC(handle, at_tensor_numel_ptr); +} + +extern "C" MemRef1Di8 wave_get_buffer(PyObject *obj) { + // Initialize PyTorch symbols on first use + initPyTorchSymbols(); + // Extract the ATen tensor from the PyTorch Python object - void *tensor = THPVariable_Unpack(obj); + void *tensor = THPVariable_Unpack_ptr(obj); if (!tensor) { throw std::runtime_error( "wave_get_buffer: Failed to unpack PyTorch tensor. " @@ -47,22 +64,20 @@ extern "C" MemRef1Di8 wave_get_buffer(PyObject *obj) { } // Get the data pointer - void *data_ptr = at_tensor_data_ptr(tensor); + void *data_ptr = at_tensor_data_ptr_ptr(tensor); if (!data_ptr) { throw std::runtime_error("wave_get_buffer: Tensor has NULL data pointer"); } // Calculate total size in bytes - int64_t numel = at_tensor_numel(tensor); - int64_t element_size = at_tensor_element_size(tensor); - int64_t total_bytes = numel * element_size; + int64_t numel = at_tensor_numel_ptr(tensor); // Create and return memref descriptor MemRef1Di8 descriptor; descriptor.basePtr = static_cast(data_ptr); descriptor.data = static_cast(data_ptr); descriptor.offset = 0; - descriptor.sizes[0] = total_bytes; + descriptor.sizes[0] = numel; descriptor.strides[0] = 1; return descriptor; From 16fbe8393436fe145206265ed98578fd9718a4be Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 16 Nov 2025 19:21:22 +0100 Subject: [PATCH 27/77] buffer_utils WIP Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/compile.py | 10 +-- .../wave/execution_engine/CMakeLists.txt | 42 +--------- .../wave/execution_engine/buffer_utils.cpp | 76 ++++--------------- .../wave/execution_engine/buffer_utils.h | 4 +- .../wave/execution_engine/execution_engine.py | 6 +- 5 files changed, 26 insertions(+), 112 deletions(-) diff --git a/wave_lang/kernel/wave/compile.py b/wave_lang/kernel/wave/compile.py index 569811282..f3fa94e38 100644 --- a/wave_lang/kernel/wave/compile.py +++ b/wave_lang/kernel/wave/compile.py @@ -302,7 +302,9 @@ def __init__(self, options: WaveCompileOptions, module: Module | bytes | str): import ctypes num_kernel_args = len(self.options.kernel_usages) - arg_types = [ctypes.c_void_p] * (num_kernel_args + 1) # +1 for stream pointer + arg_types = [ctypes.c_void_p] + [ + ctypes.py_object + ] * num_kernel_args # +1 for stream pointer func_type = ctypes.CFUNCTYPE(None, *arg_types) self._cfunc = func_type(self._host_func_ptr) @@ -316,13 +318,9 @@ def invoke(self, *args, **kwargs): # Get the current CUDA stream stream_ptr = torch.cuda.current_stream().cuda_stream - # Convert arguments to PyObject* pointers using id() - # id() returns the memory address of the Python object - py_args = [id(arg) for arg in args] - # Call the JIT-compiled host wrapper function # Signature: void func(void* stream, void* arg0, void* arg1, ...) - self._cfunc(stream_ptr, *py_args) + self._cfunc(stream_ptr, *args) # Return None (kernel modifies output tensors in place) return None diff --git a/wave_lang/kernel/wave/execution_engine/CMakeLists.txt b/wave_lang/kernel/wave/execution_engine/CMakeLists.txt index dd7a9bd0c..aae27c959 100644 --- a/wave_lang/kernel/wave/execution_engine/CMakeLists.txt +++ b/wave_lang/kernel/wave/execution_engine/CMakeLists.txt @@ -54,15 +54,10 @@ add_library(wave_runtime_helpers SHARED buffer_utils.cpp ) -# Set visibility for wave_runtime_helpers -set_target_properties(wave_runtime_helpers PROPERTIES - CXX_VISIBILITY_PRESET default - VISIBILITY_INLINES_HIDDEN OFF -) - -# Link Python for wave_runtime_helpers +# Link Python and nanobind for wave_runtime_helpers target_link_libraries(wave_runtime_helpers PRIVATE Python::Python + nanobind ) # Compile HIP runtime as a separate shared library @@ -70,20 +65,13 @@ add_library(wave_hip_runtime SHARED wave_hip_runtime.cpp ) -# Set visibility for wave_hip_runtime -set_target_properties(wave_hip_runtime PROPERTIES - CXX_VISIBILITY_PRESET default - VISIBILITY_INLINES_HIDDEN OFF -) - # Compile an extension library -add_library(wave_execution_engine MODULE +nanobind_add_module(wave_execution_engine NB_STATIC execution_engine.cpp bindings.cpp ) # Disable RTTI for execution_engine.cpp to avoid typeinfo symbol dependencies -# Only for GCC/Clang on Unix-like systems if(UNIX AND (CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang")) set_source_files_properties(execution_engine.cpp PROPERTIES COMPILE_FLAGS "-fno-rtti") endif() @@ -91,9 +79,6 @@ endif() # Include current directory for header files target_include_directories(wave_execution_engine PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) -# Link against nanobind -target_link_libraries(wave_execution_engine PRIVATE nanobind) - get_property(mlir_dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) get_property(mlir_conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) get_property(mlir_extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) @@ -132,25 +117,4 @@ if(UNIX AND NOT APPLE) ) endif() -# Enable size optimizations -nanobind_opt_size(wave_execution_engine) - -# Enable link time optimization -nanobind_lto(wave_execution_engine) - -# Set the default symbol visibility to 'hidden' -nanobind_set_visibility(wave_execution_engine) - -# Strip unneeded symbols and debug info from the binary (only active in release builds) -nanobind_strip(wave_execution_engine) - -# Disable the stack protector -nanobind_disable_stack_protector(wave_execution_engine) - -# Set the Python extension suffix -nanobind_extension(wave_execution_engine) - -# Set important linker flags -nanobind_link_options(wave_execution_engine) - install(TARGETS wave_execution_engine wave_runtime_helpers wave_hip_runtime DESTINATION ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) diff --git a/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp b/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp index da4de743b..33ed60ec2 100644 --- a/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp +++ b/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp @@ -4,73 +4,23 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #include "buffer_utils.h" -#include -#include -#include +#include #include -#include -// PyTorch Tensor C API function pointers -typedef void *(*THPVariable_Unpack_t)(PyObject *); -typedef void *(*at_tensor_data_ptr_t)(void *); -typedef int64_t (*at_tensor_numel_t)(void *); -typedef int64_t (*at_tensor_element_size_t)(void *); +namespace nb = nanobind; -// Global function pointers (initialized on first use) -static THPVariable_Unpack_t THPVariable_Unpack_ptr = nullptr; -static at_tensor_data_ptr_t at_tensor_data_ptr_ptr = nullptr; -static at_tensor_numel_t at_tensor_numel_ptr = nullptr; +extern "C" void _mlir_ciface_wave_get_buffer(MemRef1Di8 *ret, + PyObject *obj_ptr) { + // Wrap PyObject* in nanobind::object for safe access + nb::object obj = nb::borrow(obj_ptr); -// Helper to get symbol address -static void *get_symbol_address(void *handle, const char *symbol_name) { - return dlsym(handle, symbol_name); -} - -// Macro to load a function pointer and check for errors -#define GET_FUNC(handle, name) \ - do { \ - name = \ - reinterpret_cast(get_symbol_address(handle, #name)); \ - if (!name) { \ - throw std::runtime_error("Failed to load PyTorch symbol: " + \ - std::string(#name)); \ - } \ - } while (0) - -/// Initialize PyTorch C API function pointers using dlsym -static void initPyTorchSymbols() { - if (THPVariable_Unpack_ptr != nullptr) { - return; // Already initialized - } - - // Use RTLD_DEFAULT to search all loaded libraries - void *handle = RTLD_DEFAULT; - - GET_FUNC(handle, THPVariable_Unpack_ptr); - GET_FUNC(handle, at_tensor_data_ptr_ptr); - GET_FUNC(handle, at_tensor_numel_ptr); -} - -extern "C" MemRef1Di8 wave_get_buffer(PyObject *obj) { - // Initialize PyTorch symbols on first use - initPyTorchSymbols(); - - // Extract the ATen tensor from the PyTorch Python object - void *tensor = THPVariable_Unpack_ptr(obj); - if (!tensor) { - throw std::runtime_error( - "wave_get_buffer: Failed to unpack PyTorch tensor. " - "Object is not a valid torch.Tensor."); - } - - // Get the data pointer - void *data_ptr = at_tensor_data_ptr_ptr(tensor); - if (!data_ptr) { - throw std::runtime_error("wave_get_buffer: Tensor has NULL data pointer"); - } + // Call tensor.data_ptr() to get the data pointer + nb::object data_ptr_result = obj.attr("data_ptr")(); + void *data_ptr = + reinterpret_cast(nb::cast(data_ptr_result)); - // Calculate total size in bytes - int64_t numel = at_tensor_numel_ptr(tensor); + // Get tensor.numel() for the number of elements + int64_t numel = nb::cast(obj.attr("numel")()); // Create and return memref descriptor MemRef1Di8 descriptor; @@ -80,5 +30,5 @@ extern "C" MemRef1Di8 wave_get_buffer(PyObject *obj) { descriptor.sizes[0] = numel; descriptor.strides[0] = 1; - return descriptor; + *ret = descriptor; } diff --git a/wave_lang/kernel/wave/execution_engine/buffer_utils.h b/wave_lang/kernel/wave/execution_engine/buffer_utils.h index b2f2801a4..23f045873 100644 --- a/wave_lang/kernel/wave/execution_engine/buffer_utils.h +++ b/wave_lang/kernel/wave/execution_engine/buffer_utils.h @@ -30,10 +30,10 @@ extern "C" { /// - basePtr: pointer to the raw data /// - data: same as basePtr (no alignment offset) /// - offset: 0 -/// - sizes[0]: total size in bytes +/// - sizes[0]: number of elements /// - strides[0]: 1 /// /// This function assumes the PyObject is a PyTorch tensor and uses /// the PyTorch C API to extract the data pointer and size. -MemRef1Di8 wave_get_buffer(PyObject *obj); +void _mlir_ciface_wave_get_buffer(MemRef1Di8 *ret, PyObject *obj); } diff --git a/wave_lang/kernel/wave/execution_engine/execution_engine.py b/wave_lang/kernel/wave/execution_engine/execution_engine.py index 85a63ac2d..d619eab90 100644 --- a/wave_lang/kernel/wave/execution_engine/execution_engine.py +++ b/wave_lang/kernel/wave/execution_engine/execution_engine.py @@ -83,8 +83,10 @@ def _load_runtime_helpers(): symbol_map = {} - wave_get_buffer_addr = ctypes.cast(lib.wave_get_buffer, ctypes.c_void_p).value - symbol_map["wave_get_buffer"] = wave_get_buffer_addr + wave_get_buffer_addr = ctypes.cast( + lib._mlir_ciface_wave_get_buffer, ctypes.c_void_p + ).value + symbol_map["_mlir_ciface_wave_get_buffer"] = wave_get_buffer_addr return symbol_map From a8f9eebd539811460fea17594a82f5c2af38a175 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 16 Nov 2025 19:51:17 +0100 Subject: [PATCH 28/77] python runtime Signed-off-by: Ivan Butygin --- .../wave/execution_engine/CMakeLists.txt | 6 +- .../wave/execution_engine/buffer_utils.cpp | 90 ++++++++++++++----- 2 files changed, 68 insertions(+), 28 deletions(-) diff --git a/wave_lang/kernel/wave/execution_engine/CMakeLists.txt b/wave_lang/kernel/wave/execution_engine/CMakeLists.txt index aae27c959..d46ed4f04 100644 --- a/wave_lang/kernel/wave/execution_engine/CMakeLists.txt +++ b/wave_lang/kernel/wave/execution_engine/CMakeLists.txt @@ -32,9 +32,6 @@ execute_process( COMMAND_ERROR_IS_FATAL ANY) find_package(nanobind CONFIG REQUIRED) -# Build the core parts of nanobind once -nanobind_build_library(nanobind STATIC) - # Find LLVM and MLIR find_package(LLVM REQUIRED CONFIG) find_package(MLIR REQUIRED CONFIG) @@ -54,10 +51,9 @@ add_library(wave_runtime_helpers SHARED buffer_utils.cpp ) -# Link Python and nanobind for wave_runtime_helpers +# Link Python for wave_runtime_helpers target_link_libraries(wave_runtime_helpers PRIVATE Python::Python - nanobind ) # Compile HIP runtime as a separate shared library diff --git a/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp b/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp index 33ed60ec2..81f816d3b 100644 --- a/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp +++ b/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp @@ -4,31 +4,75 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #include "buffer_utils.h" -#include -#include +#include -namespace nb = nanobind; +#include + +namespace { +struct GILState { + GILState() : gstate(PyGILState_Ensure()) {} + ~GILState() { PyGILState_Release(gstate); } + PyGILState_STATE gstate; +}; + +struct PyDeleter { + void operator()(PyObject *obj) const { Py_DECREF(obj); } +}; + +using PyObjectPtr = std::unique_ptr; +} // namespace extern "C" void _mlir_ciface_wave_get_buffer(MemRef1Di8 *ret, PyObject *obj_ptr) { - // Wrap PyObject* in nanobind::object for safe access - nb::object obj = nb::borrow(obj_ptr); - - // Call tensor.data_ptr() to get the data pointer - nb::object data_ptr_result = obj.attr("data_ptr")(); - void *data_ptr = - reinterpret_cast(nb::cast(data_ptr_result)); - - // Get tensor.numel() for the number of elements - int64_t numel = nb::cast(obj.attr("numel")()); - - // Create and return memref descriptor - MemRef1Di8 descriptor; - descriptor.basePtr = static_cast(data_ptr); - descriptor.data = static_cast(data_ptr); - descriptor.offset = 0; - descriptor.sizes[0] = numel; - descriptor.strides[0] = 1; - - *ret = descriptor; + GILState gil_state; + + // Get tensor.data_ptr() method and call it + PyObjectPtr data_ptr_method(PyObject_GetAttrString(obj_ptr, "data_ptr")); + if (!data_ptr_method) { + PyErr_Clear(); + return; + } + + PyObjectPtr data_ptr_result(PyObject_CallNoArgs(data_ptr_method.get())); + + if (!data_ptr_result) { + PyErr_Clear(); + return; + } + + // Convert Python int to pointer + void *data_ptr = PyLong_AsVoidPtr(data_ptr_result.get()); + + if (!data_ptr && PyErr_Occurred()) { + PyErr_Clear(); + return; + } + + // Get tensor.numel() method and call it + PyObjectPtr numel_method(PyObject_GetAttrString(obj_ptr, "numel")); + if (!numel_method) { + PyErr_Clear(); + return; + } + + PyObjectPtr numel_result(PyObject_CallNoArgs(numel_method.get())); + + if (!numel_result) { + PyErr_Clear(); + return; + } + + int64_t numel = PyLong_AsLongLong(numel_result.get()); + + if (PyErr_Occurred()) { + PyErr_Clear(); + return; + } + + // Fill in the memref descriptor + ret->basePtr = static_cast(data_ptr); + ret->data = static_cast(data_ptr); + ret->offset = 0; + ret->sizes[0] = numel; + ret->strides[0] = 1; } From 9c8848018d1346f9b69074e798020fd2612e6729 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 16 Nov 2025 19:59:20 +0100 Subject: [PATCH 29/77] error handling Signed-off-by: Ivan Butygin --- .../wave/execution_engine/buffer_utils.cpp | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp b/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp index 81f816d3b..8febe95a5 100644 --- a/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp +++ b/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp @@ -7,6 +7,7 @@ #include #include +#include namespace { struct GILState { @@ -30,43 +31,42 @@ extern "C" void _mlir_ciface_wave_get_buffer(MemRef1Di8 *ret, PyObjectPtr data_ptr_method(PyObject_GetAttrString(obj_ptr, "data_ptr")); if (!data_ptr_method) { PyErr_Clear(); - return; + throw std::runtime_error( + "wave_get_buffer: Object does not have 'data_ptr' attribute"); } PyObjectPtr data_ptr_result(PyObject_CallNoArgs(data_ptr_method.get())); - if (!data_ptr_result) { PyErr_Clear(); - return; + throw std::runtime_error("wave_get_buffer: Failed to call data_ptr()"); } // Convert Python int to pointer void *data_ptr = PyLong_AsVoidPtr(data_ptr_result.get()); - if (!data_ptr && PyErr_Occurred()) { PyErr_Clear(); - return; + throw std::runtime_error( + "wave_get_buffer: data_ptr() did not return a valid pointer"); } // Get tensor.numel() method and call it PyObjectPtr numel_method(PyObject_GetAttrString(obj_ptr, "numel")); if (!numel_method) { PyErr_Clear(); - return; + throw std::runtime_error( + "wave_get_buffer: Object does not have 'numel' attribute"); } PyObjectPtr numel_result(PyObject_CallNoArgs(numel_method.get())); - if (!numel_result) { PyErr_Clear(); - return; + throw std::runtime_error("wave_get_buffer: Failed to call numel()"); } int64_t numel = PyLong_AsLongLong(numel_result.get()); - if (PyErr_Occurred()) { PyErr_Clear(); - return; + throw std::runtime_error("wave_get_buffer: numel() returned invalid value"); } // Fill in the memref descriptor From 00511944582bdbcc449fe942602d34f0a13643b6 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 16 Nov 2025 20:29:25 +0100 Subject: [PATCH 30/77] module loading Signed-off-by: Ivan Butygin --- .../wave/execution_engine/execution_engine.py | 3 - .../execution_engine/wave_hip_runtime.cpp | 175 ++++++++++++++---- .../wave/execution_engine/wave_hip_runtime.h | 42 +---- 3 files changed, 141 insertions(+), 79 deletions(-) diff --git a/wave_lang/kernel/wave/execution_engine/execution_engine.py b/wave_lang/kernel/wave/execution_engine/execution_engine.py index d619eab90..37db19ffd 100644 --- a/wave_lang/kernel/wave/execution_engine/execution_engine.py +++ b/wave_lang/kernel/wave/execution_engine/execution_engine.py @@ -116,9 +116,6 @@ def _load_hip_runtime(): wave_launch_kernel_addr = ctypes.cast(lib.wave_launch_kernel, ctypes.c_void_p).value symbol_map["wave_launch_kernel"] = wave_launch_kernel_addr - wave_unload_kernel_addr = ctypes.cast(lib.wave_unload_kernel, ctypes.c_void_p).value - symbol_map["wave_unload_kernel"] = wave_unload_kernel_addr - return symbol_map diff --git a/wave_lang/kernel/wave/execution_engine/wave_hip_runtime.cpp b/wave_lang/kernel/wave/execution_engine/wave_hip_runtime.cpp index a32cee3b9..2fd25dc52 100644 --- a/wave_lang/kernel/wave/execution_engine/wave_hip_runtime.cpp +++ b/wave_lang/kernel/wave/execution_engine/wave_hip_runtime.cpp @@ -4,49 +4,144 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #include "wave_hip_runtime.h" -#include -#include -#include - -// TODO: Include HIP headers when implementing -// #include - -extern "C" WaveKernelHandle wave_load_kernel(const char *binary_path, - const char *kernel_name) { - // TODO: Implement kernel loading - // 1. Load binary file - // 2. Use hipModuleLoadData or similar - // 3. Get kernel function handle - // 4. Return opaque handle - - fprintf(stderr, - "wave_load_kernel: stub implementation (binary=%s, kernel=%s)\n", - binary_path, kernel_name); - return nullptr; + +#include +#include +#include + +#if defined(__linux__) +#include // dlopen, dlsym, dlerror +using module_handle_t = void *; +#elif defined(_WIN32) +#define WIN32_LEAN_AND_MEAN +#include // LoadLibrary, GetProcAddress, GetLastError +using module_handle_t = HMODULE; +#else +#error "Unsupported platform" +#endif + +// HIP constants and types +#define HIP_LAUNCH_PARAM_BUFFER_POINTER ((void *)0x01) +#define HIP_LAUNCH_PARAM_BUFFER_SIZE ((void *)0x02) +#define HIP_LAUNCH_PARAM_END ((void *)0x03) + +using hipError_t = int; +using hipStream_t = void *; +using hipFunction_t = void *; +using hipModule_t = void *; + +using hipModuleLaunchKernel_t = hipError_t (*)(hipFunction_t, unsigned int, + unsigned int, unsigned int, + unsigned int, unsigned int, + unsigned int, unsigned int, + hipStream_t, void **, void **); + +using hipGetErrorName_t = const char *(*)(hipError_t); +using hipGetErrorString_t = const char *(*)(hipError_t); +using hipModuleUnload_t = hipError_t (*)(hipModule_t); +using hipModuleLoadData_t = hipError_t (*)(hipModule_t *, const void *); +using hipModuleGetFunction_t = hipError_t (*)(hipFunction_t *, hipModule_t, + const char *); + +// Global function pointers +static hipModuleLaunchKernel_t hipModuleLaunchKernel = nullptr; +static hipGetErrorName_t hipGetErrorName = nullptr; +static hipGetErrorString_t hipGetErrorString = nullptr; +static hipModuleUnload_t hipModuleUnload = nullptr; +static hipModuleLoadData_t hipModuleLoadData = nullptr; +static hipModuleGetFunction_t hipModuleGetFunction = nullptr; + +static void *get_symbol_address(module_handle_t module, + const char *symbol_name) { +#if defined(__linux__) + return dlsym(module, symbol_name); +#elif defined(_WIN32) + return reinterpret_cast(GetProcAddress(module, symbol_name)); +#endif } -extern "C" int wave_launch_kernel(WaveKernelHandle handle, void *stream, - uint32_t grid_x, uint32_t grid_y, - uint32_t grid_z, uint32_t block_x, - uint32_t block_y, uint32_t block_z, - void **args, size_t num_args) { - // TODO: Implement kernel launch - // 1. Validate handle - // 2. Set up launch parameters - // 3. Use hipModuleLaunchKernel or similar - // 4. Return status - - fprintf(stderr, - "wave_launch_kernel: stub implementation (grid=[%u,%u,%u], " - "block=[%u,%u,%u], args=%zu)\n", - grid_x, grid_y, grid_z, block_x, block_y, block_z, num_args); - return -1; // Not implemented +#define GET_FUNC(module, name) \ + do { \ + name = \ + reinterpret_cast(get_symbol_address(module, #name)); \ + if (!name) { \ + throw std::runtime_error("Failed to load symbol: " + \ + std::string(#name)); \ + } \ + } while (0) + +static void load_hip_functions() { + // Return early if already loaded + if (hipModuleLaunchKernel && hipGetErrorName && hipGetErrorString && + hipModuleUnload && hipModuleLoadData && hipModuleGetFunction) + return; + + module_handle_t module = nullptr; + +#if defined(__linux__) + module = dlopen("libamdhip64.so", RTLD_NOW); + if (!module) { + throw std::runtime_error("Failed to load libamdhip64.so: " + + std::string(dlerror())); + } +#elif defined(_WIN32) + module = LoadLibrary("amdhip64.dll"); + if (!module) { + DWORD error_code = GetLastError(); + throw std::runtime_error("Failed to load amdhip64.dll: error code " + + std::to_string(error_code)); + } +#endif + + GET_FUNC(module, hipModuleLaunchKernel); + GET_FUNC(module, hipGetErrorName); + GET_FUNC(module, hipGetErrorString); + GET_FUNC(module, hipModuleUnload); + GET_FUNC(module, hipModuleLoadData); + GET_FUNC(module, hipModuleGetFunction); } -extern "C" void wave_unload_kernel(WaveKernelHandle handle) { - // TODO: Implement kernel unloading - // 1. Free module resources - // 2. Clean up handle +#undef GET_FUNC + +#define HIP_CHECK_EXC(expr) \ + do { \ + hipError_t e = (expr); \ + if (e) { \ + const char *errName = hipGetErrorName(e); \ + const char *errMsg = hipGetErrorString(e); \ + std::ostringstream msg; \ + msg << "Error " << e << "(" << errName << ") " << __FILE__ << ":" \ + << __LINE__ << ": " << std::endl \ + << #expr << std::endl \ + << errMsg << std::endl; \ + throw std::runtime_error(msg.str()); \ + } \ + } while (0) + +extern "C" void *wave_load_kernel(void *stream, void **cached_kernel_handle, + const void *binary_pointer, + size_t /*binary_size*/, + const char *kernel_name) { + load_hip_functions(); + + hipFunction_t function = *cached_kernel_handle; + if (function) + return function; + + hipModule_t mod = nullptr; + HIP_CHECK_EXC(hipModuleLoadData(&mod, binary_pointer)); + HIP_CHECK_EXC(hipModuleGetFunction(&function, mod, kernel_name)); + *cached_kernel_handle = function; + + return function; +} - fprintf(stderr, "wave_unload_kernel: stub implementation\n"); +extern "C" void wave_launch_kernel(void *stream, void *function, + int shared_memory_bytes, int grid_x, + int grid_y, int grid_z, int block_x, + int block_y, int block_z, void **args, + int /*num_args*/) { + HIP_CHECK_EXC(hipModuleLaunchKernel(function, grid_x, grid_y, grid_z, block_x, + block_y, block_z, shared_memory_bytes, + stream, args, nullptr)); } diff --git a/wave_lang/kernel/wave/execution_engine/wave_hip_runtime.h b/wave_lang/kernel/wave/execution_engine/wave_hip_runtime.h index 3ebcb896e..479c685b3 100644 --- a/wave_lang/kernel/wave/execution_engine/wave_hip_runtime.h +++ b/wave_lang/kernel/wave/execution_engine/wave_hip_runtime.h @@ -6,44 +6,14 @@ #pragma once #include -#include extern "C" { -/// Opaque kernel handle type -typedef void *WaveKernelHandle; +void *wave_load_kernel(void *stream, void **cached_kernel_handle, + const void *binary_pointer, size_t binary_size, + const char *kernel_name); -/// Load a GPU kernel from a binary file -/// -/// Args: -/// binary_path: Path to the kernel binary (.hsaco file) -/// kernel_name: Name of the kernel function to load -/// -/// Returns: -/// Opaque kernel handle, or nullptr on failure -WaveKernelHandle wave_load_kernel(const char *binary_path, - const char *kernel_name); - -/// Launch a GPU kernel -/// -/// Args: -/// handle: Kernel handle from wave_load_kernel -/// stream: HIP stream pointer -/// grid_x, grid_y, grid_z: Grid dimensions -/// block_x, block_y, block_z: Block dimensions -/// args: Pointer to array of kernel argument pointers -/// num_args: Number of kernel arguments -/// -/// Returns: -/// 0 on success, non-zero on failure -int wave_launch_kernel(WaveKernelHandle handle, void *stream, uint32_t grid_x, - uint32_t grid_y, uint32_t grid_z, uint32_t block_x, - uint32_t block_y, uint32_t block_z, void **args, - size_t num_args); - -/// Unload a GPU kernel -/// -/// Args: -/// handle: Kernel handle from wave_load_kernel -void wave_unload_kernel(WaveKernelHandle handle); +void wave_launch_kernel(void *stream, void *function, int shared_memory_bytes, + int grid_x, int grid_y, int grid_z, int block_x, + int block_y, int block_z, void **args, int num_args); } From 3b40efd1184539df4746e0e321043f28f0803bbf Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 16 Nov 2025 22:12:16 +0100 Subject: [PATCH 31/77] load hip funcs explicitly Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/execution_engine/execution_engine.py | 3 +++ wave_lang/kernel/wave/execution_engine/wave_hip_runtime.cpp | 4 +--- wave_lang/kernel/wave/execution_engine/wave_hip_runtime.h | 1 + 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/wave_lang/kernel/wave/execution_engine/execution_engine.py b/wave_lang/kernel/wave/execution_engine/execution_engine.py index 37db19ffd..7dd1e4323 100644 --- a/wave_lang/kernel/wave/execution_engine/execution_engine.py +++ b/wave_lang/kernel/wave/execution_engine/execution_engine.py @@ -107,6 +107,9 @@ def _load_hip_runtime(): lib = _load_library("wave_hip_runtime") + # Load HIP functions eagerly + lib.load_functions() + symbol_map = {} # Register HIP runtime functions diff --git a/wave_lang/kernel/wave/execution_engine/wave_hip_runtime.cpp b/wave_lang/kernel/wave/execution_engine/wave_hip_runtime.cpp index 2fd25dc52..4a2c68493 100644 --- a/wave_lang/kernel/wave/execution_engine/wave_hip_runtime.cpp +++ b/wave_lang/kernel/wave/execution_engine/wave_hip_runtime.cpp @@ -70,7 +70,7 @@ static void *get_symbol_address(module_handle_t module, } \ } while (0) -static void load_hip_functions() { +extern "C" void load_functions() { // Return early if already loaded if (hipModuleLaunchKernel && hipGetErrorName && hipGetErrorString && hipModuleUnload && hipModuleLoadData && hipModuleGetFunction) @@ -122,8 +122,6 @@ extern "C" void *wave_load_kernel(void *stream, void **cached_kernel_handle, const void *binary_pointer, size_t /*binary_size*/, const char *kernel_name) { - load_hip_functions(); - hipFunction_t function = *cached_kernel_handle; if (function) return function; diff --git a/wave_lang/kernel/wave/execution_engine/wave_hip_runtime.h b/wave_lang/kernel/wave/execution_engine/wave_hip_runtime.h index 479c685b3..2b27dd3a4 100644 --- a/wave_lang/kernel/wave/execution_engine/wave_hip_runtime.h +++ b/wave_lang/kernel/wave/execution_engine/wave_hip_runtime.h @@ -8,6 +8,7 @@ #include extern "C" { +void load_functions(); void *wave_load_kernel(void *stream, void **cached_kernel_handle, const void *binary_pointer, size_t binary_size, From cfe5a094222b01a8989d9b0049653db799261ba4 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 17 Nov 2025 00:47:46 +0100 Subject: [PATCH 32/77] scalar args Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/codegen/emitter.py | 1 + wave_lang/kernel/wave/compile.py | 9 ++++--- .../wave/execution_engine/buffer_utils.cpp | 24 +++++++++++++++++++ .../wave/execution_engine/buffer_utils.h | 8 +++++++ .../wave/execution_engine/execution_engine.py | 6 +++++ 5 files changed, 45 insertions(+), 3 deletions(-) diff --git a/wave_lang/kernel/wave/codegen/emitter.py b/wave_lang/kernel/wave/codegen/emitter.py index e236fcdd0..306aa8932 100644 --- a/wave_lang/kernel/wave/codegen/emitter.py +++ b/wave_lang/kernel/wave/codegen/emitter.py @@ -20,6 +20,7 @@ from wave_lang.kernel.ops.wave_ops import get_custom from wave_lang.kernel.lang import Memory from wave_lang.kernel.lang.kernel_buffer import KernelBuffer +from wave_lang.kernel.compiler.kernel_codegen import BindingType from wave_lang.kernel.compiler.utils import strides_from_symbolic_shape from wave_lang.kernel.lang.global_symbols import * from wave_lang.support.logging import get_logger diff --git a/wave_lang/kernel/wave/compile.py b/wave_lang/kernel/wave/compile.py index f3fa94e38..54350a2c9 100644 --- a/wave_lang/kernel/wave/compile.py +++ b/wave_lang/kernel/wave/compile.py @@ -3,6 +3,8 @@ from typing import Any, Optional, Callable, Sequence import torch +import ctypes +from ctypes import py_object from wave_lang.kernel.lang import IndexSymbol from wave_lang.support.ir_imports import Module, stream_d @@ -299,11 +301,10 @@ def __init__(self, options: WaveCompileOptions, module: Module | bytes | str): # Create ctypes function type once # The host wrapper signature is: void func(void* stream, void* arg0, void* arg1, ...) - import ctypes num_kernel_args = len(self.options.kernel_usages) arg_types = [ctypes.c_void_p] + [ - ctypes.py_object + py_object ] * num_kernel_args # +1 for stream pointer func_type = ctypes.CFUNCTYPE(None, *arg_types) self._cfunc = func_type(self._host_func_ptr) @@ -315,12 +316,14 @@ def invoke(self, *args, **kwargs): """ Invokes the wave kernel with the given arguments using the ExecutionEngine. """ + + assert not kwargs, "kwargs are not supported" # Get the current CUDA stream stream_ptr = torch.cuda.current_stream().cuda_stream # Call the JIT-compiled host wrapper function # Signature: void func(void* stream, void* arg0, void* arg1, ...) - self._cfunc(stream_ptr, *args) + self._cfunc(stream_ptr, *(py_object(arg) for arg in args)) # Return None (kernel modifies output tensors in place) return None diff --git a/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp b/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp index 8febe95a5..064eb30d3 100644 --- a/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp +++ b/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp @@ -76,3 +76,27 @@ extern "C" void _mlir_ciface_wave_get_buffer(MemRef1Di8 *ret, ret->sizes[0] = numel; ret->strides[0] = 1; } + +extern "C" int64_t wave_get_int64(PyObject *obj_ptr) { + GILState gil_state; + + int64_t value = PyLong_AsLongLong(obj_ptr); + if (PyErr_Occurred()) { + PyErr_Clear(); + throw std::runtime_error("wave_get_int64: Failed to convert to int64"); + } + + return value; +} + +extern "C" double wave_get_float64(PyObject *obj_ptr) { + GILState gil_state; + + double value = PyFloat_AsDouble(obj_ptr); + if (PyErr_Occurred()) { + PyErr_Clear(); + throw std::runtime_error("wave_get_float64: Failed to convert to double"); + } + + return value; +} diff --git a/wave_lang/kernel/wave/execution_engine/buffer_utils.h b/wave_lang/kernel/wave/execution_engine/buffer_utils.h index 23f045873..2089bfc28 100644 --- a/wave_lang/kernel/wave/execution_engine/buffer_utils.h +++ b/wave_lang/kernel/wave/execution_engine/buffer_utils.h @@ -36,4 +36,12 @@ extern "C" { /// This function assumes the PyObject is a PyTorch tensor and uses /// the PyTorch C API to extract the data pointer and size. void _mlir_ciface_wave_get_buffer(MemRef1Di8 *ret, PyObject *obj); + +/// Extract an int64_t value from a PyObject. +/// Throws std::runtime_error if conversion fails. +int64_t wave_get_int64(PyObject *obj); + +/// Extract a double value from a PyObject. +/// Throws std::runtime_error if conversion fails. +double wave_get_float64(PyObject *obj); } diff --git a/wave_lang/kernel/wave/execution_engine/execution_engine.py b/wave_lang/kernel/wave/execution_engine/execution_engine.py index 7dd1e4323..95afa156c 100644 --- a/wave_lang/kernel/wave/execution_engine/execution_engine.py +++ b/wave_lang/kernel/wave/execution_engine/execution_engine.py @@ -88,6 +88,12 @@ def _load_runtime_helpers(): ).value symbol_map["_mlir_ciface_wave_get_buffer"] = wave_get_buffer_addr + wave_get_int64_addr = ctypes.cast(lib.wave_get_int64, ctypes.c_void_p).value + symbol_map["wave_get_int64"] = wave_get_int64_addr + + wave_get_float64_addr = ctypes.cast(lib.wave_get_float64, ctypes.c_void_p).value + symbol_map["wave_get_float64"] = wave_get_float64_addr + return symbol_map From e6218cdc58f5eff16b7bbe5646c1115dbf1a3f82 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 17 Nov 2025 13:05:25 +0100 Subject: [PATCH 33/77] fix test_dynamic_copy Signed-off-by: Ivan Butygin --- tests/kernel/wave/wave_e2e_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py index b0e5e4b84..4e768d904 100644 --- a/tests/kernel/wave/wave_e2e_test.py +++ b/tests/kernel/wave/wave_e2e_test.py @@ -238,6 +238,7 @@ def test( canonicalize=True, run_bench=run_bench, use_buffer_ops=use_buffer_ops, + dynamic_symbols=[M, N], ) options = set_default_run_config(options) test = wave_compile(options, test) From 07350658854032b5a1061f525e8bca10e8aa3747 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 17 Nov 2025 13:06:25 +0100 Subject: [PATCH 34/77] get_dim impl Signed-off-by: Ivan Butygin --- .../wave/execution_engine/buffer_utils.cpp | 36 +++++++++++++++++++ .../wave/execution_engine/buffer_utils.h | 14 ++++++++ .../wave/execution_engine/execution_engine.py | 3 ++ 3 files changed, 53 insertions(+) diff --git a/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp b/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp index 064eb30d3..ec48ba757 100644 --- a/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp +++ b/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp @@ -100,3 +100,39 @@ extern "C" double wave_get_float64(PyObject *obj_ptr) { return value; } + +extern "C" int64_t wave_get_dim(PyObject *obj_ptr, int32_t dim_idx) { + GILState gil_state; + + // Get tensor.size() method + PyObjectPtr size_method(PyObject_GetAttrString(obj_ptr, "size")); + if (!size_method) { + PyErr_Clear(); + throw std::runtime_error( + "wave_get_dim: Object does not have 'size' attribute"); + } + + // Call tensor.size(dim_idx) + PyObjectPtr dim_arg(PyLong_FromLong(dim_idx)); + if (!dim_arg) { + PyErr_Clear(); + throw std::runtime_error( + "wave_get_dim: Failed to create dimension argument"); + } + + PyObjectPtr size_result( + PyObject_CallOneArg(size_method.get(), dim_arg.get())); + if (!size_result) { + PyErr_Clear(); + throw std::runtime_error("wave_get_dim: Failed to call size()"); + } + + // Convert result to int64_t + int64_t dim_size = PyLong_AsLongLong(size_result.get()); + if (PyErr_Occurred()) { + PyErr_Clear(); + throw std::runtime_error("wave_get_dim: size() returned invalid value"); + } + + return dim_size; +} diff --git a/wave_lang/kernel/wave/execution_engine/buffer_utils.h b/wave_lang/kernel/wave/execution_engine/buffer_utils.h index 2089bfc28..740dd0562 100644 --- a/wave_lang/kernel/wave/execution_engine/buffer_utils.h +++ b/wave_lang/kernel/wave/execution_engine/buffer_utils.h @@ -44,4 +44,18 @@ int64_t wave_get_int64(PyObject *obj); /// Extract a double value from a PyObject. /// Throws std::runtime_error if conversion fails. double wave_get_float64(PyObject *obj); + +/// Extract the size of a specific dimension from a PyObject (PyTorch tensor). +/// +/// Args: +/// obj: PyObject* pointing to a PyTorch tensor +/// dim_idx: Dimension index to query (0-based) +/// +/// Returns: +/// Size of the specified dimension as int64_t +/// +/// Throws: +/// std::runtime_error if the object doesn't have a size() method or +/// if the dimension index is invalid +int64_t wave_get_dim(PyObject *obj, int32_t dim_idx); } diff --git a/wave_lang/kernel/wave/execution_engine/execution_engine.py b/wave_lang/kernel/wave/execution_engine/execution_engine.py index 95afa156c..024e4ac73 100644 --- a/wave_lang/kernel/wave/execution_engine/execution_engine.py +++ b/wave_lang/kernel/wave/execution_engine/execution_engine.py @@ -94,6 +94,9 @@ def _load_runtime_helpers(): wave_get_float64_addr = ctypes.cast(lib.wave_get_float64, ctypes.c_void_p).value symbol_map["wave_get_float64"] = wave_get_float64_addr + wave_get_dim_addr = ctypes.cast(lib.wave_get_dim, ctypes.c_void_p).value + symbol_map["wave_get_dim"] = wave_get_dim_addr + return symbol_map From ae60dbb22dcf6ef9530351beffe7d8bc9afa766e Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 17 Nov 2025 13:39:05 +0100 Subject: [PATCH 35/77] erase original func Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/codegen/emitter.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/wave_lang/kernel/wave/codegen/emitter.py b/wave_lang/kernel/wave/codegen/emitter.py index 306aa8932..af6eb82a1 100644 --- a/wave_lang/kernel/wave/codegen/emitter.py +++ b/wave_lang/kernel/wave/codegen/emitter.py @@ -324,6 +324,8 @@ def abi_type(binding: BindingDesc): ): old_arg.replace_all_uses_with(new_value) + kernel_func.erase() + gpu_d.return_([]) kernel_func.erase() From 41d285643f5bde620c2e2c2b4bc69a1a53b244bd Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 17 Nov 2025 13:40:19 +0100 Subject: [PATCH 36/77] cleanup cmake Signed-off-by: Ivan Butygin --- .../kernel/wave/execution_engine/CMakeLists.txt | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/wave_lang/kernel/wave/execution_engine/CMakeLists.txt b/wave_lang/kernel/wave/execution_engine/CMakeLists.txt index d46ed4f04..71ec8ecb2 100644 --- a/wave_lang/kernel/wave/execution_engine/CMakeLists.txt +++ b/wave_lang/kernel/wave/execution_engine/CMakeLists.txt @@ -6,12 +6,6 @@ cmake_minimum_required(VERSION 3.19...3.27) project(wave_execution_engine) -# Skip building on macOS -if(APPLE) - message(STATUS "Skipping wave_execution_engine build on ${CMAKE_SYSTEM_NAME}") - return() -endif() - # Set the C++ standard set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) @@ -32,7 +26,6 @@ execute_process( COMMAND_ERROR_IS_FATAL ANY) find_package(nanobind CONFIG REQUIRED) -# Find LLVM and MLIR find_package(LLVM REQUIRED CONFIG) find_package(MLIR REQUIRED CONFIG) @@ -46,22 +39,18 @@ include_directories(${MLIR_INCLUDE_DIRS}) separate_arguments(LLVM_DEFINITIONS_LIST NATIVE_COMMAND ${LLVM_DEFINITIONS}) add_definitions(${LLVM_DEFINITIONS_LIST}) -# Compile buffer utils as a separate shared library add_library(wave_runtime_helpers SHARED buffer_utils.cpp ) -# Link Python for wave_runtime_helpers target_link_libraries(wave_runtime_helpers PRIVATE Python::Python ) -# Compile HIP runtime as a separate shared library add_library(wave_hip_runtime SHARED wave_hip_runtime.cpp ) -# Compile an extension library nanobind_add_module(wave_execution_engine NB_STATIC execution_engine.cpp bindings.cpp @@ -79,7 +68,6 @@ get_property(mlir_dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) get_property(mlir_conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) get_property(mlir_extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) -# Link against LLVM and MLIR libraries target_link_libraries(wave_execution_engine PRIVATE Python::Python Python::Module From 7855c4cf46aca2a49451a7f1fef0d9839c707da5 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 17 Nov 2025 22:01:03 +0100 Subject: [PATCH 37/77] more comments Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/codegen/emitter.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/wave_lang/kernel/wave/codegen/emitter.py b/wave_lang/kernel/wave/codegen/emitter.py index af6eb82a1..306aa8932 100644 --- a/wave_lang/kernel/wave/codegen/emitter.py +++ b/wave_lang/kernel/wave/codegen/emitter.py @@ -324,8 +324,6 @@ def abi_type(binding: BindingDesc): ): old_arg.replace_all_uses_with(new_value) - kernel_func.erase() - gpu_d.return_([]) kernel_func.erase() From 7d7af5ad7d8cd698848556a051c9f354f3295cd2 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 22 Nov 2025 21:37:14 +0100 Subject: [PATCH 38/77] install Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/execution_engine/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wave_lang/kernel/wave/execution_engine/CMakeLists.txt b/wave_lang/kernel/wave/execution_engine/CMakeLists.txt index 71ec8ecb2..1a2ec6794 100644 --- a/wave_lang/kernel/wave/execution_engine/CMakeLists.txt +++ b/wave_lang/kernel/wave/execution_engine/CMakeLists.txt @@ -101,4 +101,4 @@ if(UNIX AND NOT APPLE) ) endif() -install(TARGETS wave_execution_engine wave_runtime_helpers wave_hip_runtime DESTINATION ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) +install(TARGETS wave_execution_engine wave_runtime_helpers wave_hip_runtime LIBRARY DESTINATION .) From e466b9dbecad70a98b162e77ebbb0e7a4016716c Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 22 Nov 2025 23:02:45 +0100 Subject: [PATCH 39/77] local water Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/water.py | 39 ++++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/wave_lang/kernel/wave/water.py b/wave_lang/kernel/wave/water.py index f7ccef482..27610be3d 100644 --- a/wave_lang/kernel/wave/water.py +++ b/wave_lang/kernel/wave/water.py @@ -12,8 +12,9 @@ import subprocess import sys import math -from typing import Any, Sequence +from typing import Any, Sequence, Optional import importlib +from functools import lru_cache from wave_lang.support.ir_imports import ( Attribute, @@ -173,26 +174,36 @@ def replace_ops_and_collect_subspans(op: Operation) -> WalkResult: return local_module.get_asm(binary=False, print_generic_op_form=True) -def find_binary(name: str) -> str: +def find_local_water_binary_path(name: str) -> Optional[str]: this_path = Path(__file__).parent - path = this_path / "water_mlir" / "bin" / name - assert path.is_file(), f"Could not find the {name} executable at {path}" - return str(path) + tool_path = this_path / "water_mlir" / "bin" / name + if not tool_path.is_file() or not os.access(tool_path, os.X_OK): + return None + return str(tool_path) +@lru_cache def is_water_available() -> bool: - """Returns True if the water_mlir package is available.""" - try: - return ( - importlib.util.find_spec("wave_lang.kernel.wave.water_mlir.water_mlir") - is not None - ) - except Exception: - return False + """Returns True of the water_mlir package is available.""" + if (Path(__file__).parent / "water_mlir").exists(): + return True + + return importlib.util.find_spec("water_mlir") is not None +@lru_cache def get_water_binary_path() -> str: - return find_binary("water-opt") + path = find_local_water_binary_path("water-opt") + if path: + return path + + try: + from water_mlir import binaries as water_bin + except ImportError as err: + raise RuntimeError( + "optional water_mlir module not installed but its use is requested" + ) from err + return water_bin.find_binary("water-opt") def make_linear_pass_pipeline( From 225f4bf2afd0da68840431119e313777c4eb57ab Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 23 Nov 2025 02:00:08 +0100 Subject: [PATCH 40/77] tests Signed-off-by: Ivan Butygin --- tests/kernel/wave/wave_e2e_test.py | 9 ++++++--- wave_lang/kernel/wave/water.py | 1 + 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py index 4e768d904..3830ee9c4 100644 --- a/tests/kernel/wave/wave_e2e_test.py +++ b/tests/kernel/wave/wave_e2e_test.py @@ -138,8 +138,9 @@ def test( @require_e2e @pytest.mark.parametrize("shape", get_test_shapes("test_copy")) @param_bool("use_buffer_ops", "buf_ops") +@param_bool("use_water_pipeline", "water") @check_leaks -def test_copy(shape, use_buffer_ops, run_bench): +def test_copy(shape, use_buffer_ops, run_bench, use_water_pipeline): M = tkl.sym.M N = tkl.sym.N ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE @@ -183,6 +184,7 @@ def test( canonicalize=True, run_bench=run_bench, use_buffer_ops=use_buffer_ops, + use_water_pipeline=use_water_pipeline, ) options = set_default_run_config(options) test = wave_compile(options, test) @@ -194,7 +196,8 @@ def test( @require_e2e @pytest.mark.parametrize("shape", get_test_shapes("test_copy")) @param_bool("use_buffer_ops", "buf_ops") -def test_dynamic_copy(shape, use_buffer_ops, run_bench): +@param_bool("use_water_pipeline", "water") +def test_dynamic_copy(shape, use_buffer_ops, run_bench, use_water_pipeline): M = tkl.sym.M N = tkl.sym.N ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE @@ -238,7 +241,7 @@ def test( canonicalize=True, run_bench=run_bench, use_buffer_ops=use_buffer_ops, - dynamic_symbols=[M, N], + use_water_pipeline=use_water_pipeline, ) options = set_default_run_config(options) test = wave_compile(options, test) diff --git a/wave_lang/kernel/wave/water.py b/wave_lang/kernel/wave/water.py index 27610be3d..8b148523e 100644 --- a/wave_lang/kernel/wave/water.py +++ b/wave_lang/kernel/wave/water.py @@ -398,6 +398,7 @@ def water_lowering_pipeline(module: Module, target_chip: str) -> Module: "cse", "loop-invariant-code-motion", "int-range-optimizations", + ("convert-amdgpu-to-rocdl", {"chipset": target_chip}), ("convert-gpu-to-rocdl", {"use-bare-ptr-memref-call-conv": "1"}, "gpu.module"), ("rocdl-attach-target", {"chip": target_chip}), ("gpu-to-llvm", {"use-bare-pointers-for-kernels": "1"}), From 8db164885148a7487eb725be9cba7ee8a1cbcae0 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 23 Nov 2025 02:03:14 +0100 Subject: [PATCH 41/77] refa coptions Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/compile.py | 2 +- wave_lang/kernel/wave/water.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/wave_lang/kernel/wave/compile.py b/wave_lang/kernel/wave/compile.py index 54350a2c9..2182317d3 100644 --- a/wave_lang/kernel/wave/compile.py +++ b/wave_lang/kernel/wave/compile.py @@ -525,7 +525,7 @@ def get_binary_path(): if options.use_water_pipeline: from .water import water_lowering_pipeline - module = water_lowering_pipeline(mb.module_op, options.target) + module = water_lowering_pipeline(mb.module_op, options) return WaveKernel2(options, module) elif not options.compile_to_mlir: diff --git a/wave_lang/kernel/wave/water.py b/wave_lang/kernel/wave/water.py index 8b148523e..229d37735 100644 --- a/wave_lang/kernel/wave/water.py +++ b/wave_lang/kernel/wave/water.py @@ -16,6 +16,7 @@ import importlib from functools import lru_cache +from wave_lang.kernel.wave.compile_options import WaveCompileOptions from wave_lang.support.ir_imports import ( Attribute, BlockArgument, @@ -389,9 +390,10 @@ def diagnostic_from_json( print("[info] No out-of-bounds accesses detected.") -def water_lowering_pipeline(module: Module, target_chip: str) -> Module: +def water_lowering_pipeline(module: Module, options: WaveCompileOptions) -> Module: binary = get_water_binary_path() mlir_asm = module.operation.get_asm() + target_chip = options.target pipeline = [ "lower-affine", "canonicalize", From 35e9ed6c616d4aaeb8cfbc5cd58d741654e395d2 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 23 Nov 2025 02:22:33 +0100 Subject: [PATCH 42/77] simplify execution_engine Signed-off-by: Ivan Butygin --- .../execution_engine/execution_engine.cpp | 117 +++--------------- 1 file changed, 17 insertions(+), 100 deletions(-) diff --git a/wave_lang/kernel/wave/execution_engine/execution_engine.cpp b/wave_lang/kernel/wave/execution_engine/execution_engine.cpp index d7640a24a..756c0d97c 100644 --- a/wave_lang/kernel/wave/execution_engine/execution_engine.cpp +++ b/wave_lang/kernel/wave/execution_engine/execution_engine.cpp @@ -22,6 +22,7 @@ #include #include +#include #include #include #include @@ -41,103 +42,13 @@ #define DEBUG_TYPE "wave-execution-engine" -// Ensure LLVM native target is initialized only once -static std::once_flag llvmInitFlag; - static void initializeLLVMTarget() { - llvm::InitializeNativeTarget(); - llvm::InitializeNativeTargetAsmPrinter(); - llvm::InitializeNativeTargetAsmParser(); -} - -static llvm::OptimizationLevel mapToLevel(llvm::CodeGenOptLevel level) { - unsigned optimizeSize = 0; // TODO: unhardcode - - switch (level) { - default: - llvm_unreachable("Invalid optimization level!"); - - case llvm::CodeGenOptLevel::None: - return llvm::OptimizationLevel::O0; - - case llvm::CodeGenOptLevel::Less: - return llvm::OptimizationLevel::O1; - - case llvm::CodeGenOptLevel::Default: - switch (optimizeSize) { - default: - llvm_unreachable("Invalid optimization level for size!"); - - case 0: - return llvm::OptimizationLevel::O2; - - case 1: - return llvm::OptimizationLevel::Os; - - case 2: - return llvm::OptimizationLevel::Oz; - } - - case llvm::CodeGenOptLevel::Aggressive: - return llvm::OptimizationLevel::O3; - } -} - -static llvm::PipelineTuningOptions -getPipelineTuningOptions(llvm::CodeGenOptLevel optLevelVal) { - llvm::PipelineTuningOptions pto; - auto level = static_cast(optLevelVal); - - pto.LoopUnrolling = level > 0; - pto.LoopVectorization = level > 1; - pto.SLPVectorization = level > 1; - return pto; -} - -static void runOptimizationPasses(llvm::Module &M, llvm::TargetMachine &TM) { - llvm::CodeGenOptLevel optLevelVal = TM.getOptLevel(); - - llvm::LoopAnalysisManager lam; - llvm::FunctionAnalysisManager fam; - llvm::CGSCCAnalysisManager cgam; - llvm::ModuleAnalysisManager mam; - - llvm::PassInstrumentationCallbacks pic; - llvm::PrintPassOptions ppo; - ppo.Indent = false; - ppo.SkipAnalyses = false; - llvm::StandardInstrumentations si(M.getContext(), /*debugLogging*/ false, - /*verifyEach*/ true, ppo); - - si.registerCallbacks(pic, &mam); - - llvm::PassBuilder pb(&TM, getPipelineTuningOptions(optLevelVal)); - - llvm::ModulePassManager mpm; - - if (/*verify*/ true) { - pb.registerPipelineStartEPCallback( - [&](llvm::ModulePassManager &mpm, llvm::OptimizationLevel level) { - mpm.addPass(createModuleToFunctionPassAdaptor(llvm::VerifierPass())); - }); - } - - // Register all the basic analyses with the managers. - pb.registerModuleAnalyses(mam); - pb.registerCGSCCAnalyses(cgam); - pb.registerFunctionAnalyses(fam); - pb.registerLoopAnalyses(lam); - pb.crossRegisterProxies(lam, fam, cgam, mam); - - llvm::OptimizationLevel level = mapToLevel(optLevelVal); - - if (optLevelVal == llvm::CodeGenOptLevel::None) { - mpm = pb.buildO0DefaultPipeline(level); - } else { - mpm = pb.buildPerModuleDefaultPipeline(level); - } - - mpm.run(M, mam); + static bool initOnce = []() { + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + llvm::InitializeNativeTargetAsmParser(); + return true; + }(); } /// A simple object cache following Lang's LLJITWithObjectCache example. @@ -209,14 +120,17 @@ static void setupModule(llvm::Module &M, llvm::TargetMachine &TM) { namespace { class CustomCompiler : public llvm::orc::SimpleCompiler { public: + using Optimizer = std::function; using Transformer = std::function; using AsmPrinter = std::function; CustomCompiler(Transformer t, AsmPrinter a, std::unique_ptr TM, llvm::ObjectCache *ObjCache = nullptr) - : SimpleCompiler(*TM, ObjCache), TM(std::move(TM)), - transformer(std::move(t)), printer(std::move(a)) {} + : SimpleCompiler(*TM, ObjCache), + optimizer(mlir::makeOptimizingTransformer(/*opLevel*/ 3, + /*sizeLevel*/ 0, TM.get())), + TM(std::move(TM)), transformer(std::move(t)), printer(std::move(a)) {} llvm::Expected operator()(llvm::Module &M) override { if (transformer) { @@ -226,7 +140,9 @@ class CustomCompiler : public llvm::orc::SimpleCompiler { } setupModule(M, *TM); - runOptimizationPasses(M, *TM); + auto error = optimizer(&M); + if (error) + return error; if (printer) { llvm::SmallVector buffer; @@ -245,6 +161,7 @@ class CustomCompiler : public llvm::orc::SimpleCompiler { } private: + Optimizer optimizer; std::shared_ptr TM; Transformer transformer; AsmPrinter printer; @@ -266,7 +183,7 @@ wave::ExecutionEngine::ExecutionEngine(const ExecutionEngineOptions &options) } // Initialize LLVM native target (only once per process) - std::call_once(llvmInitFlag, initializeLLVMTarget); + initializeLLVMTarget(); auto tmBuilder = llvm::cantFail(llvm::orc::JITTargetMachineBuilder::detectHost()); From d8d3fac0fe480fe5ac70c0d3560b5a7a075a5ffe Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 23 Nov 2025 02:28:31 +0100 Subject: [PATCH 43/77] opt_pass Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/water.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/wave_lang/kernel/wave/water.py b/wave_lang/kernel/wave/water.py index 229d37735..18a3dd746 100644 --- a/wave_lang/kernel/wave/water.py +++ b/wave_lang/kernel/wave/water.py @@ -394,10 +394,13 @@ def water_lowering_pipeline(module: Module, options: WaveCompileOptions) -> Modu binary = get_water_binary_path() mlir_asm = module.operation.get_asm() target_chip = options.target + opt_pass = "composite-fixed-point-pass", { + "name": "CompositePass", + "pipeline": "any(canonicalize,cse)", + } pipeline = [ "lower-affine", - "canonicalize", - "cse", + opt_pass, "loop-invariant-code-motion", "int-range-optimizations", ("convert-amdgpu-to-rocdl", {"chipset": target_chip}), @@ -405,13 +408,11 @@ def water_lowering_pipeline(module: Module, options: WaveCompileOptions) -> Modu ("rocdl-attach-target", {"chip": target_chip}), ("gpu-to-llvm", {"use-bare-pointers-for-kernels": "1"}), "reconcile-unrealized-casts", - "canonicalize", - "cse", + opt_pass, "gpu-module-to-binary", "water-gpu-to-gpu-runtime", "symbol-dce", - "canonicalize", - "cse", + opt_pass, ] try: result = subprocess.check_output( From 342601e5e61fee0e6756f67e31fe0d3b3f575acf Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 23 Nov 2025 02:47:30 +0100 Subject: [PATCH 44/77] print IR after all Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/water.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/wave_lang/kernel/wave/water.py b/wave_lang/kernel/wave/water.py index 18a3dd746..962e3a399 100644 --- a/wave_lang/kernel/wave/water.py +++ b/wave_lang/kernel/wave/water.py @@ -414,9 +414,13 @@ def water_lowering_pipeline(module: Module, options: WaveCompileOptions) -> Modu "symbol-dce", opt_pass, ] + args = [binary, make_linear_pass_pipeline(pipeline)] + if options.mlir_print_ir_after_all: + args.append("--mlir-print-ir-after-all") + try: result = subprocess.check_output( - [binary, make_linear_pass_pipeline(pipeline)], + args, input=mlir_asm, text=True, ) From d8dbefcca83ea7547c411417818b5c9cd3d2d1ef Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 23 Nov 2025 02:56:36 +0100 Subject: [PATCH 45/77] rename class Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/compile.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/wave_lang/kernel/wave/compile.py b/wave_lang/kernel/wave/compile.py index 2182317d3..6bf34ef8b 100644 --- a/wave_lang/kernel/wave/compile.py +++ b/wave_lang/kernel/wave/compile.py @@ -266,9 +266,12 @@ def __call__(self, *args, **kwargs): return invoke_with_profile(self.options, self.invoke, *args, **kwargs) -class WaveKernel2: - def __init__(self, options: WaveCompileOptions, module: Module | bytes | str): +class WaveKernelExecutionEngine: + def __init__( + self, options: WaveCompileOptions, module: Module | bytes | str, mlir_asm: str + ): self.options = options + self.asm = mlir_asm self._engine = None self._module_handle = None @@ -318,15 +321,14 @@ def invoke(self, *args, **kwargs): """ assert not kwargs, "kwargs are not supported" - # Get the current CUDA stream + # Get the current stream stream_ptr = torch.cuda.current_stream().cuda_stream # Call the JIT-compiled host wrapper function # Signature: void func(void* stream, void* arg0, void* arg1, ...) self._cfunc(stream_ptr, *(py_object(arg) for arg in args)) - # Return None (kernel modifies output tensors in place) - return None + return self.asm def __del__(self): """Clean up the loaded module when the kernel is destroyed.""" @@ -526,7 +528,7 @@ def get_binary_path(): from .water import water_lowering_pipeline module = water_lowering_pipeline(mb.module_op, options) - return WaveKernel2(options, module) + return WaveKernelExecutionEngine(options, module, asm) elif not options.compile_to_mlir: # LLVM flow: only compile to VMFB when not in MLIR-only mode From 4f2bce4cabed975eea2b63d34a13f71fc747df77 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 23 Nov 2025 03:10:49 +0100 Subject: [PATCH 46/77] test Signed-off-by: Ivan Butygin --- tests/kernel/wave/common/utils.py | 2 +- tests/kernel/wave/wave_e2e_test.py | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/kernel/wave/common/utils.py b/tests/kernel/wave/common/utils.py index e7c7355c7..e31f9972c 100644 --- a/tests/kernel/wave/common/utils.py +++ b/tests/kernel/wave/common/utils.py @@ -65,4 +65,4 @@ def param_bool(name, shortname=None, values=None): shortname = shortname or name values = values or [False, True] ids = [f"{shortname}" if v else f"no_{shortname}" for v in values] - return pytest.mark.parametrize(name, [pytest.param(v) for v in values], ids=ids) + return pytest.mark.parametrize(name, values, ids=ids) diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py index 3830ee9c4..1e96f3aef 100644 --- a/tests/kernel/wave/wave_e2e_test.py +++ b/tests/kernel/wave/wave_e2e_test.py @@ -50,6 +50,7 @@ require_rdna4, ) from .common.shapes import get_test_shapes as get_common_test_shape +from wave_lang.kernel.wave.water import is_water_available default_test_shapes = [ @@ -196,7 +197,14 @@ def test( @require_e2e @pytest.mark.parametrize("shape", get_test_shapes("test_copy")) @param_bool("use_buffer_ops", "buf_ops") -@param_bool("use_water_pipeline", "water") +@param_bool( + "use_water_pipeline", + "water", + values=[ + False, + pytest.param(True, marks=pytest.mark.skipif(not is_water_available())), + ], +) def test_dynamic_copy(shape, use_buffer_ops, run_bench, use_water_pipeline): M = tkl.sym.M N = tkl.sym.N From 29c68c6fdceb02bf18cd19dc749108332909379b Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 23 Nov 2025 10:09:08 +0100 Subject: [PATCH 47/77] fix test Signed-off-by: Ivan Butygin --- tests/kernel/wave/wave_e2e_test.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py index 1e96f3aef..a13cfc4af 100644 --- a/tests/kernel/wave/wave_e2e_test.py +++ b/tests/kernel/wave/wave_e2e_test.py @@ -136,10 +136,17 @@ def test( assert os.path.exists(vmfb_file) +_need_water = pytest.mark.skipif( + not is_water_available(), reason="Water MLIR package not installed." +) + + @require_e2e @pytest.mark.parametrize("shape", get_test_shapes("test_copy")) @param_bool("use_buffer_ops", "buf_ops") -@param_bool("use_water_pipeline", "water") +@param_bool( + "use_water_pipeline", "water", values=[False, pytest.param(True, marks=_need_water)] +) @check_leaks def test_copy(shape, use_buffer_ops, run_bench, use_water_pipeline): M = tkl.sym.M @@ -198,12 +205,7 @@ def test( @pytest.mark.parametrize("shape", get_test_shapes("test_copy")) @param_bool("use_buffer_ops", "buf_ops") @param_bool( - "use_water_pipeline", - "water", - values=[ - False, - pytest.param(True, marks=pytest.mark.skipif(not is_water_available())), - ], + "use_water_pipeline", "water", values=[False, pytest.param(True, marks=_need_water)] ) def test_dynamic_copy(shape, use_buffer_ops, run_bench, use_water_pipeline): M = tkl.sym.M From 460206e4c7b0c35f5ddd97117348b4d0d8faa093 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 23 Nov 2025 10:11:55 +0100 Subject: [PATCH 48/77] register dialect Signed-off-by: Ivan Butygin --- water/tools/water-opt/water-opt.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/water/tools/water-opt/water-opt.cpp b/water/tools/water-opt/water-opt.cpp index 125349387..430d9dd5d 100644 --- a/water/tools/water-opt/water-opt.cpp +++ b/water/tools/water-opt/water-opt.cpp @@ -50,6 +50,7 @@ int main(int argc, char **argv) { mlir::registerAllGPUToLLVMIRTranslations(registry); + registry.insert(); mlir::water::test::registerWaterTestDialect(registry); return mlir::asMainReturnCode( From 18ace68dfeb047b7e1bec0d3ae13bb9b1664aeb7 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 23 Nov 2025 10:49:41 +0100 Subject: [PATCH 49/77] lib Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/execution_engine/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/wave_lang/kernel/wave/execution_engine/CMakeLists.txt b/wave_lang/kernel/wave/execution_engine/CMakeLists.txt index 1a2ec6794..db60dcedb 100644 --- a/wave_lang/kernel/wave/execution_engine/CMakeLists.txt +++ b/wave_lang/kernel/wave/execution_engine/CMakeLists.txt @@ -75,6 +75,7 @@ target_link_libraries(wave_execution_engine PRIVATE LLVM${LLVM_NATIVE_ARCH}AsmParser LLVM${LLVM_NATIVE_ARCH}CodeGen LLVM${LLVM_NATIVE_ARCH}Desc + LLVMExecutionEngine LLVMOrcJIT LLVMTarget MLIRCAPIDebug From b8f1ba1f77130a2ec86fad530330e4642565a84d Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 23 Nov 2025 19:28:06 +0100 Subject: [PATCH 50/77] skip shared libs Signed-off-by: Ivan Butygin --- .github/workflows/ci-gpu.yaml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci-gpu.yaml b/.github/workflows/ci-gpu.yaml index 8189a4959..805e32a77 100644 --- a/.github/workflows/ci-gpu.yaml +++ b/.github/workflows/ci-gpu.yaml @@ -237,7 +237,10 @@ jobs: fail-fast: false matrix: name: [linux-mi325-1gpu-ossci-iree-org] - shared_libs: ["ON", "OFF"] + shared_libs: ["OFF"] + # TODO: we are linking water shared libs with static LLVM libraries, + # which doesn't really work if multiple of them linked with LLVM proper. + # shared_libs: ["ON", "OFF"] runs-on: ${{ matrix.name }} timeout-minutes: 60 needs: build_llvm_linux From 57b6a2f4d74f61dae8c09b275d8337de80df349f Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 23 Nov 2025 20:57:06 +0100 Subject: [PATCH 51/77] fix tests Signed-off-by: Ivan Butygin --- tests/kernel/wave/common/utils.py | 12 ++++++++++++ .../wave/test_execution_engine_wrapper.py | 18 +++++++----------- tests/kernel/wave/wave_e2e_test.py | 19 +++++++++---------- .../kernel/wave/execution_engine/__init__.py | 11 +++-------- .../wave/execution_engine/execution_engine.py | 4 ++++ 5 files changed, 35 insertions(+), 29 deletions(-) diff --git a/tests/kernel/wave/common/utils.py b/tests/kernel/wave/common/utils.py index e31f9972c..572c2b967 100644 --- a/tests/kernel/wave/common/utils.py +++ b/tests/kernel/wave/common/utils.py @@ -66,3 +66,15 @@ def param_bool(name, shortname=None, values=None): values = values or [False, True] ids = [f"{shortname}" if v else f"no_{shortname}" for v in values] return pytest.mark.parametrize(name, values, ids=ids) + + +def _is_water_lowering_available() -> bool: + from wave_lang.kernel.wave.water import is_water_available + from wave_lang.kernel.wave.execution_engine import is_execution_engine_available + + return is_water_available() and is_execution_engine_available() + + +reguire_water_lowering = pytest.mark.skipif( + not _is_water_lowering_available(), reason="Water lowering is not available." +) diff --git a/tests/kernel/wave/test_execution_engine_wrapper.py b/tests/kernel/wave/test_execution_engine_wrapper.py index e8af0b5f5..56987776b 100644 --- a/tests/kernel/wave/test_execution_engine_wrapper.py +++ b/tests/kernel/wave/test_execution_engine_wrapper.py @@ -12,19 +12,15 @@ import pytest import weakref -try: - from wave_lang.kernel.wave.execution_engine import ( - get_execution_engine, - clear_engine_cache, - is_engine_cached, - ) - - EXECUTION_ENGINE_AVAILABLE = True -except ImportError: - EXECUTION_ENGINE_AVAILABLE = False +from wave_lang.kernel.wave.execution_engine import ( + clear_engine_cache, + get_execution_engine, + is_engine_cached, + is_execution_engine_available, +) pytestmark = pytest.mark.skipif( - not EXECUTION_ENGINE_AVAILABLE, + not is_execution_engine_available(), reason="ExecutionEngine not available (C++ extension not built)", ) diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py index a13cfc4af..f45a434f2 100644 --- a/tests/kernel/wave/wave_e2e_test.py +++ b/tests/kernel/wave/wave_e2e_test.py @@ -44,13 +44,13 @@ param_bool, perf_test, require_cdna3, - require_e2e, - require_cdna_2_or_3_or_4, require_cdna4, + require_cdna_2_or_3_or_4, + require_e2e, require_rdna4, + reguire_water_lowering, ) from .common.shapes import get_test_shapes as get_common_test_shape -from wave_lang.kernel.wave.water import is_water_available default_test_shapes = [ @@ -136,16 +136,13 @@ def test( assert os.path.exists(vmfb_file) -_need_water = pytest.mark.skipif( - not is_water_available(), reason="Water MLIR package not installed." -) - - @require_e2e @pytest.mark.parametrize("shape", get_test_shapes("test_copy")) @param_bool("use_buffer_ops", "buf_ops") @param_bool( - "use_water_pipeline", "water", values=[False, pytest.param(True, marks=_need_water)] + "use_water_pipeline", + "water", + values=[False, pytest.param(True, marks=reguire_water_lowering)], ) @check_leaks def test_copy(shape, use_buffer_ops, run_bench, use_water_pipeline): @@ -205,7 +202,9 @@ def test( @pytest.mark.parametrize("shape", get_test_shapes("test_copy")) @param_bool("use_buffer_ops", "buf_ops") @param_bool( - "use_water_pipeline", "water", values=[False, pytest.param(True, marks=_need_water)] + "use_water_pipeline", + "water", + values=[False, pytest.param(True, marks=reguire_water_lowering)], ) def test_dynamic_copy(shape, use_buffer_ops, run_bench, use_water_pipeline): M = tkl.sym.M diff --git a/wave_lang/kernel/wave/execution_engine/__init__.py b/wave_lang/kernel/wave/execution_engine/__init__.py index 1378bf122..7a5f292b3 100644 --- a/wave_lang/kernel/wave/execution_engine/__init__.py +++ b/wave_lang/kernel/wave/execution_engine/__init__.py @@ -12,18 +12,12 @@ variables. """ -# Import C++ bindings (may not be available if not built yet) -try: - from wave_execution_engine import ExecutionEngine, ExecutionEngineOptions -except ImportError: - ExecutionEngine = None - ExecutionEngineOptions = None - # Import Python wrapper with caching from .execution_engine import ( - get_execution_engine, clear_engine_cache, + get_execution_engine, is_engine_cached, + is_execution_engine_available, ) __all__ = [ @@ -31,6 +25,7 @@ "ExecutionEngine", "ExecutionEngineOptions", # Python wrapper + "is_execution_engine_available", "get_execution_engine", "clear_engine_cache", "is_engine_cached", diff --git a/wave_lang/kernel/wave/execution_engine/execution_engine.py b/wave_lang/kernel/wave/execution_engine/execution_engine.py index 024e4ac73..abae20fa1 100644 --- a/wave_lang/kernel/wave/execution_engine/execution_engine.py +++ b/wave_lang/kernel/wave/execution_engine/execution_engine.py @@ -23,6 +23,10 @@ ExecutionEngineOptions = None +def is_execution_engine_available() -> bool: + return ExecutionEngine is not None and ExecutionEngineOptions is not None + + # Global weak reference to the cached ExecutionEngine instance _cached_engine: Optional[weakref.ref] = None From 2b77f8984dc7baee5700a2a3746db1091857f8aa Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 23 Nov 2025 21:26:46 +0100 Subject: [PATCH 52/77] lit test Signed-off-by: Ivan Butygin --- lit_tests/kernel/wave/water_lowering.py | 110 ++++++++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 lit_tests/kernel/wave/water_lowering.py diff --git a/lit_tests/kernel/wave/water_lowering.py b/lit_tests/kernel/wave/water_lowering.py new file mode 100644 index 000000000..89c56b12b --- /dev/null +++ b/lit_tests/kernel/wave/water_lowering.py @@ -0,0 +1,110 @@ +# RUN: python %s | FileCheck %s + +import wave_lang.kernel.lang as tkl +import wave_lang.kernel.wave as tkw +from wave_lang.kernel.lang.global_symbols import * +from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile +from wave_lang.kernel.wave.utils.general_utils import ( + run_test, +) +from wave_lang.support.location_config import ( + LocationCaptureConfig, + LocationCaptureLevel, +) + +M = tkl.sym.M +N = tkl.sym.N +K = tkl.sym.K +B = tkl.sym.B +BLOCK_M = tkl.sym.BLOCK_M +BLOCK_N = tkl.sym.BLOCK_N +BLOCK_K = tkl.sym.BLOCK_K +BLOCK_B = tkl.sym.BLOCK_B +LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD +STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD +ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE +ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0 + + +def get_wave_compile_options( + canonicalize: bool = False, + dynamic_symbols=[], + additional_symbols={}, + location_capture_config=LocationCaptureConfig( + level=LocationCaptureLevel.FILE_LINE_COL + ), + drop_debug_info_before_mlir=True, +): + bindings = { + M: 16, + N: 16, + K: 16, + BLOCK_M: 16, + BLOCK_N: 16, + BLOCK_K: 16, + ADDRESS_SPACE: tkl.AddressSpace.SHARED_MEMORY.value, + } + bindings.update(additional_symbols) + + # Remove dynamic symbols from the bindings. + for sym in dynamic_symbols: + if sym in bindings: + del bindings[sym] + + return WaveCompileOptions( + subs=bindings, + canonicalize=canonicalize, + dynamic_symbols=dynamic_symbols, + compile_to_mlir=True, + location_capture_config=location_capture_config, + drop_debug_info_before_mlir=drop_debug_info_before_mlir, + use_water_pipeline=True, + ) + + +@run_test +def test_read_write(): + constraints: list[tkw.Constraint] = [ + tkw.HardwareConstraint(threads_per_wave=64, vector_shapes={M: 16, N: 16}) + ] + constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WaveConstraint(M, BLOCK_M)] + constraints += [tkw.WaveConstraint(N, BLOCK_N)] + + @tkw.wave(constraints) + def read_write( + a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], + b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], + ): + res = tkw.read(a) + tkw.write(res, b) + + read_write = wave_compile(get_wave_compile_options(canonicalize=True), read_write) + print(read_write.asm) + + # CHECK-LABEL: test_read_write + # CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 - (s0 floordiv 64) * 48)> + # CHECK: gpu.module @gpu_module + # CHECK: gpu.func @read_write + # CHECK-SAME: (%[[D0:.*]]: memref, %[[D1:.*]]: memref) + # CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + # CHECK-DAG: %[[thread_id_x:.*]] = gpu.thread_id x + # CHECK: %[[S0:.*]] = memref.reinterpret_cast %[[D0]] to offset: [0], sizes: [16, 16], strides: [16, 1] : memref to memref<16x16xf16, strided<[16, 1]>> + # CHECK: %[[S1:.*]] = memref.reinterpret_cast %[[D1]] to offset: [0], sizes: [16, 16], strides: [16, 1] : memref to memref<16x16xf16, strided<[16, 1]>> + # CHECK: %[[I0:.*]] = affine.apply #[[MAP0]]()[%[[thread_id_x]]] + # CHECK: %[[V:.*]] = vector.load %[[S0]][%[[I0]], %[[C0]]] : memref<16x16xf16, strided<[16, 1]>>, vector<16xf16> + # CHECK: vector.store %[[V]], %[[S1]][%[[I0]], %[[C0]]] : memref<16x16xf16, strided<[16, 1]>>, vector<16xf16> + # CHECK: return + + # CHECK-LABEL: func.func @isolated_benchmark + # CHECK-SAME: (%[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr, %[[ARG2:.*]]: !llvm.ptr) + # CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + # CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + # CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index + # CHECK: %[[BUF1:.*]] = call @wave_get_buffer(%[[ARG1]]) : (!llvm.ptr) -> memref + # CHECK: %[[VIEW1:.*]] = memref.view %[[BUF1]][%[[C0]]][] : memref to memref + # CHECK: %[[BUF2:.*]] = call @wave_get_buffer(%[[ARG2]]) : (!llvm.ptr) -> memref + # CHECK: %[[VIEW2:.*]] = memref.view %[[BUF2]][%[[C0]]][] : memref to memref + # CHECK: gpu.launch_func @gpu_module::@read_write blocks in (%[[C1]], %[[C1]], %[[C1]]) threads in (%[[C64]], %[[C1]], %[[C1]]) args(%[[VIEW1]] : memref, %[[VIEW2]] : memref) + # CHECK: return From d714cf7c26f99327afe36ec262663f912f224b4f Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 23 Nov 2025 21:32:17 +0100 Subject: [PATCH 53/77] REQUIRES Signed-off-by: Ivan Butygin --- lit_tests/kernel/wave/water_lowering.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lit_tests/kernel/wave/water_lowering.py b/lit_tests/kernel/wave/water_lowering.py index 89c56b12b..057a7d350 100644 --- a/lit_tests/kernel/wave/water_lowering.py +++ b/lit_tests/kernel/wave/water_lowering.py @@ -1,3 +1,4 @@ +# REQUIRES: water # RUN: python %s | FileCheck %s import wave_lang.kernel.lang as tkl From c3edecc6c33f45c93ad0ac5c2287b0c423ed44f9 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 23 Nov 2025 22:20:56 +0100 Subject: [PATCH 54/77] fix typo Signed-off-by: Ivan Butygin --- tests/kernel/wave/common/utils.py | 2 +- tests/kernel/wave/wave_e2e_test.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/kernel/wave/common/utils.py b/tests/kernel/wave/common/utils.py index 572c2b967..e3137c63b 100644 --- a/tests/kernel/wave/common/utils.py +++ b/tests/kernel/wave/common/utils.py @@ -75,6 +75,6 @@ def _is_water_lowering_available() -> bool: return is_water_available() and is_execution_engine_available() -reguire_water_lowering = pytest.mark.skipif( +require_water_lowering = pytest.mark.skipif( not _is_water_lowering_available(), reason="Water lowering is not available." ) diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py index f45a434f2..a18b8f37b 100644 --- a/tests/kernel/wave/wave_e2e_test.py +++ b/tests/kernel/wave/wave_e2e_test.py @@ -48,7 +48,7 @@ require_cdna_2_or_3_or_4, require_e2e, require_rdna4, - reguire_water_lowering, + require_water_lowering, ) from .common.shapes import get_test_shapes as get_common_test_shape @@ -142,7 +142,7 @@ def test( @param_bool( "use_water_pipeline", "water", - values=[False, pytest.param(True, marks=reguire_water_lowering)], + values=[False, pytest.param(True, marks=require_water_lowering)], ) @check_leaks def test_copy(shape, use_buffer_ops, run_bench, use_water_pipeline): @@ -204,7 +204,7 @@ def test( @param_bool( "use_water_pipeline", "water", - values=[False, pytest.param(True, marks=reguire_water_lowering)], + values=[False, pytest.param(True, marks=require_water_lowering)], ) def test_dynamic_copy(shape, use_buffer_ops, run_bench, use_water_pipeline): M = tkl.sym.M From c8135b0a85083814629bf69feff0eac39b5d5662 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 23 Nov 2025 22:29:45 +0100 Subject: [PATCH 55/77] typos and cleanup Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/compile.py | 10 ++++++---- wave_lang/kernel/wave/execution_engine/__init__.py | 2 ++ .../kernel/wave/execution_engine/execution_engine.h | 12 ++++++------ wave_lang/kernel/wave/water.py | 3 ++- 4 files changed, 16 insertions(+), 11 deletions(-) diff --git a/wave_lang/kernel/wave/compile.py b/wave_lang/kernel/wave/compile.py index 6bf34ef8b..9f05ebbf2 100644 --- a/wave_lang/kernel/wave/compile.py +++ b/wave_lang/kernel/wave/compile.py @@ -330,10 +330,12 @@ def invoke(self, *args, **kwargs): return self.asm - def __del__(self): - """Clean up the loaded module when the kernel is destroyed.""" - if self._module_handle is not None and self._engine is not None: - self._engine.release_module(self._module_handle) + # TODO: __del__ call order is not guaranteed, need a better way to clean up + # the loaded module when the kernel is destroyed. + # def __del__(self): + # """Clean up the loaded module when the kernel is destroyed.""" + # if self._module_handle is not None and self._engine is not None: + # self._engine.release_module(self._module_handle) def wave_compile( diff --git a/wave_lang/kernel/wave/execution_engine/__init__.py b/wave_lang/kernel/wave/execution_engine/__init__.py index 7a5f292b3..05798d4a2 100644 --- a/wave_lang/kernel/wave/execution_engine/__init__.py +++ b/wave_lang/kernel/wave/execution_engine/__init__.py @@ -14,6 +14,8 @@ # Import Python wrapper with caching from .execution_engine import ( + ExecutionEngine, + ExecutionEngineOptions, clear_engine_cache, get_execution_engine, is_engine_cached, diff --git a/wave_lang/kernel/wave/execution_engine/execution_engine.h b/wave_lang/kernel/wave/execution_engine/execution_engine.h index 5e1454832..66868f124 100644 --- a/wave_lang/kernel/wave/execution_engine/execution_engine.h +++ b/wave_lang/kernel/wave/execution_engine/execution_engine.h @@ -40,11 +40,11 @@ struct ExecutionEngineOptions { /// be dumped to a file via the `dumpToObjectfile` method. bool enableObjectCache = false; - /// If enable `enableGDBNotificationListener` is set, the JIT compiler will + /// If `enableGDBNotificationListener` is true, the JIT compiler will /// notify the llvm's global GDB notification listener. bool enableGDBNotificationListener = true; - /// If `enablePerfNotificationListener` is set, the JIT compiler will notify + /// If `enablePerfNotificationListener` is true, the JIT compiler will notify /// the llvm's global Perf notification listener. bool enablePerfNotificationListener = true; @@ -60,7 +60,7 @@ struct ExecutionEngineOptions { /// optimization. std::function lateTransformer; - /// If `asmPrinter` is provided, it will be called to print resulted assembly + /// If `asmPrinter` is provided, it will be called to print resulting assembly /// just before final code generation. std::function asmPrinter; }; @@ -74,8 +74,8 @@ class ExecutionEngine { ExecutionEngine(const ExecutionEngineOptions &options); ~ExecutionEngine(); - /// Compiles given module, adds it to execution engine and run its contructors - /// if any. + /// Compiles given module, adds it to execution engine and run its + /// constructors if any. llvm::Expected loadModule(mlir::ModuleOp m); /// Deserializes MLIR bytecode from a memory buffer, compiles it, and loads @@ -87,7 +87,7 @@ class ExecutionEngine { /// execution engine. llvm::Expected loadModuleFromText(llvm::StringRef mlirText); - /// Runs module desctructors and removes it from execution engine. + /// Runs module destructors and removes it from execution engine. void releaseModule(ModuleHandle handle); /// Looks up the original function with the given name and returns a diff --git a/wave_lang/kernel/wave/water.py b/wave_lang/kernel/wave/water.py index 962e3a399..5ba900ab7 100644 --- a/wave_lang/kernel/wave/water.py +++ b/wave_lang/kernel/wave/water.py @@ -185,7 +185,7 @@ def find_local_water_binary_path(name: str) -> Optional[str]: @lru_cache def is_water_available() -> bool: - """Returns True of the water_mlir package is available.""" + """Returns True if the water_mlir package is available.""" if (Path(__file__).parent / "water_mlir").exists(): return True @@ -423,6 +423,7 @@ def water_lowering_pipeline(module: Module, options: WaveCompileOptions) -> Modu args, input=mlir_asm, text=True, + stderr=subprocess.PIPE, ) except subprocess.CalledProcessError as e: print(e.stderr) From debc49e3b088d54163857a96b57acee30a90c0ff Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 24 Nov 2025 01:23:22 +0100 Subject: [PATCH 56/77] runtime pass lit Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/water.py | 1 - 1 file changed, 1 deletion(-) diff --git a/wave_lang/kernel/wave/water.py b/wave_lang/kernel/wave/water.py index 5ba900ab7..70ed6edca 100644 --- a/wave_lang/kernel/wave/water.py +++ b/wave_lang/kernel/wave/water.py @@ -423,7 +423,6 @@ def water_lowering_pipeline(module: Module, options: WaveCompileOptions) -> Modu args, input=mlir_asm, text=True, - stderr=subprocess.PIPE, ) except subprocess.CalledProcessError as e: print(e.stderr) From 3fcf325d3b13d0620fb3f0123aa7810d58c805ee Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 24 Nov 2025 02:05:13 +0100 Subject: [PATCH 57/77] typos Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/execution_engine/execution_engine.cpp | 2 +- wave_lang/kernel/wave/water.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/wave_lang/kernel/wave/execution_engine/execution_engine.cpp b/wave_lang/kernel/wave/execution_engine/execution_engine.cpp index 756c0d97c..70582d645 100644 --- a/wave_lang/kernel/wave/execution_engine/execution_engine.cpp +++ b/wave_lang/kernel/wave/execution_engine/execution_engine.cpp @@ -363,7 +363,7 @@ wave::ExecutionEngine::lookup(wave::ExecutionEngine::ModuleHandle handle, auto expectedSymbol = jit->lookup(*dylib, name); // JIT lookup may return an Error referring to strings stored internally by - // the JIT. If the Error outlives the ExecutionEngine, it would want have a + // the JIT. If the Error outlives the ExecutionEngine, it would have a // dangling reference, which is currently caught by an assertion inside JIT // thanks to hand-rolled reference counting. Rewrap the error message into a // string before returning. Alternatively, ORC JIT should consider copying diff --git a/wave_lang/kernel/wave/water.py b/wave_lang/kernel/wave/water.py index 70ed6edca..c7c29977a 100644 --- a/wave_lang/kernel/wave/water.py +++ b/wave_lang/kernel/wave/water.py @@ -213,7 +213,7 @@ def make_linear_pass_pipeline( ], ) -> str: def make_pass_arguments( - name: str, args: dict[str, Any], module_name: str = None + name: str, args: dict[str, Any], module_name: Optional[str] = None ) -> str: ret = ( name @@ -253,7 +253,7 @@ def water_leak_in_bounds_check(module: Module, override_ir: str = ""): ] def get_code_context( - filename: str, start_line: int, end_line, context: int = 2 + filename: str, start_line: int, end_line: int, context: int = 2 ) -> str: """ Retrieves a line and a few lines of context around it. From 67bf7edd160e9d5ad0b1eb4ffe557b72cf7940ca Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 24 Nov 2025 02:12:12 +0100 Subject: [PATCH 58/77] fix python Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/execution_engine/CMakeLists.txt | 2 -- 1 file changed, 2 deletions(-) diff --git a/wave_lang/kernel/wave/execution_engine/CMakeLists.txt b/wave_lang/kernel/wave/execution_engine/CMakeLists.txt index db60dcedb..39c86629f 100644 --- a/wave_lang/kernel/wave/execution_engine/CMakeLists.txt +++ b/wave_lang/kernel/wave/execution_engine/CMakeLists.txt @@ -10,8 +10,6 @@ project(wave_execution_engine) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) -find_package(Python 3.10 COMPONENTS Interpreter Development.Module REQUIRED) - if (NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES) set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build." FORCE) set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo") From ab448006cc72298e92a47e863a02047b5bb19b8b Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 24 Nov 2025 13:06:50 +0100 Subject: [PATCH 59/77] if(auto error = ...) Signed-off-by: Ivan Butygin --- .../kernel/wave/execution_engine/execution_engine.cpp | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/wave_lang/kernel/wave/execution_engine/execution_engine.cpp b/wave_lang/kernel/wave/execution_engine/execution_engine.cpp index 70582d645..697509810 100644 --- a/wave_lang/kernel/wave/execution_engine/execution_engine.cpp +++ b/wave_lang/kernel/wave/execution_engine/execution_engine.cpp @@ -134,14 +134,12 @@ class CustomCompiler : public llvm::orc::SimpleCompiler { llvm::Expected operator()(llvm::Module &M) override { if (transformer) { - auto err = transformer(M); - if (err) - return err; + if (auto error = transformer(M)) + return error; } setupModule(M, *TM); - auto error = optimizer(&M); - if (error) + if (auto error = optimizer(&M)) return error; if (printer) { From 3b414ab0a181f506522530fdb668df2b677ad4d2 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 24 Nov 2025 13:56:08 +0100 Subject: [PATCH 60/77] require_water_and_ee Signed-off-by: Ivan Butygin --- tests/kernel/wave/common/utils.py | 7 ++++--- tests/kernel/wave/wave_e2e_test.py | 6 +++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/kernel/wave/common/utils.py b/tests/kernel/wave/common/utils.py index e3137c63b..144512940 100644 --- a/tests/kernel/wave/common/utils.py +++ b/tests/kernel/wave/common/utils.py @@ -68,13 +68,14 @@ def param_bool(name, shortname=None, values=None): return pytest.mark.parametrize(name, values, ids=ids) -def _is_water_lowering_available() -> bool: +def _is_water_and_ee_available() -> bool: from wave_lang.kernel.wave.water import is_water_available from wave_lang.kernel.wave.execution_engine import is_execution_engine_available return is_water_available() and is_execution_engine_available() -require_water_lowering = pytest.mark.skipif( - not _is_water_lowering_available(), reason="Water lowering is not available." +require_water_and_ee = pytest.mark.skipif( + not _is_water_and_ee_available(), + reason="Water or execution engine are not available.", ) diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py index a18b8f37b..1fdc2b6af 100644 --- a/tests/kernel/wave/wave_e2e_test.py +++ b/tests/kernel/wave/wave_e2e_test.py @@ -48,7 +48,7 @@ require_cdna_2_or_3_or_4, require_e2e, require_rdna4, - require_water_lowering, + require_water_and_ee, ) from .common.shapes import get_test_shapes as get_common_test_shape @@ -142,7 +142,7 @@ def test( @param_bool( "use_water_pipeline", "water", - values=[False, pytest.param(True, marks=require_water_lowering)], + values=[False, pytest.param(True, marks=require_water_and_ee)], ) @check_leaks def test_copy(shape, use_buffer_ops, run_bench, use_water_pipeline): @@ -204,7 +204,7 @@ def test( @param_bool( "use_water_pipeline", "water", - values=[False, pytest.param(True, marks=require_water_lowering)], + values=[False, pytest.param(True, marks=require_water_and_ee)], ) def test_dynamic_copy(shape, use_buffer_ops, run_bench, use_water_pipeline): M = tkl.sym.M From 5a491c3c878e637f08308745e42288ff59ae8613 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 24 Nov 2025 15:39:07 +0100 Subject: [PATCH 61/77] cleanup Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/compile.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/wave_lang/kernel/wave/compile.py b/wave_lang/kernel/wave/compile.py index 9f05ebbf2..1de61f14c 100644 --- a/wave_lang/kernel/wave/compile.py +++ b/wave_lang/kernel/wave/compile.py @@ -17,6 +17,7 @@ get_temp_binary_dir, is_cache_enabled, ) +from .water import water_lowering_pipeline from .compile_options import WaveCompileOptions from .utils.compile_utils import compile_to_vmfb from .utils.general_utils import wave_dtype_to_torch @@ -281,16 +282,16 @@ def __init__( # TODO: investigate why bytecode deserialization is not working if isinstance(module, (bytes, str)): # Assume it's already MLIR text - mlir_asm = module.decode() if isinstance(module, bytes) else module + optimized_mlir = module.decode() if isinstance(module, bytes) else module else: # Serialize the MLIR module to text - mlir_asm = str(module) + optimized_mlir = str(module) - # Load module eagerly + # Get the execution engine instance and load the module from wave_lang.kernel.wave.execution_engine import get_execution_engine self._engine = get_execution_engine() - self._module_handle = self._engine.load_module_from_text(mlir_asm) + self._module_handle = self._engine.load_module_from_text(optimized_mlir) # Look up the host wrapper function func_name = self.options.func_name @@ -302,8 +303,8 @@ def __init__( f"Make sure the module was compiled with emit_host_func. Error: {e}" ) - # Create ctypes function type once - # The host wrapper signature is: void func(void* stream, void* arg0, void* arg1, ...) + # Create ctypes function type + # The host wrapper signature is: void func(void* stream, PyObject* arg0, PyObject* arg1, ...) num_kernel_args = len(self.options.kernel_usages) arg_types = [ctypes.c_void_p] + [ @@ -312,20 +313,18 @@ def __init__( func_type = ctypes.CFUNCTYPE(None, *arg_types) self._cfunc = func_type(self._host_func_ptr) - def __call__(self, *args, **kwargs): - return self.invoke(*args, **kwargs) + def __call__(self, *args): + return self.invoke(*args) - def invoke(self, *args, **kwargs): + def invoke(self, *args): """ Invokes the wave kernel with the given arguments using the ExecutionEngine. """ - - assert not kwargs, "kwargs are not supported" # Get the current stream stream_ptr = torch.cuda.current_stream().cuda_stream # Call the JIT-compiled host wrapper function - # Signature: void func(void* stream, void* arg0, void* arg1, ...) + # Signature: void func(void* stream, PyObject* arg0, PyObject* arg1, ...) self._cfunc(stream_ptr, *(py_object(arg) for arg in args)) return self.asm @@ -527,8 +526,6 @@ def get_binary_path(): if options.backend == "asm" and not options.compile_to_asm: _compile_asm_to_binary(asm, options) if options.use_water_pipeline: - from .water import water_lowering_pipeline - module = water_lowering_pipeline(mb.module_op, options) return WaveKernelExecutionEngine(options, module, asm) From 4e6b9237e1b9d0caec9df3308dd2f593f02d3c70 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 24 Nov 2025 15:47:15 +0100 Subject: [PATCH 62/77] make_linear_pass_pipeline doc Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/water.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/wave_lang/kernel/wave/water.py b/wave_lang/kernel/wave/water.py index c7c29977a..5558a5db1 100644 --- a/wave_lang/kernel/wave/water.py +++ b/wave_lang/kernel/wave/water.py @@ -212,8 +212,21 @@ def make_linear_pass_pipeline( tuple[str, dict[str, Any]] | tuple[str, dict[str, Any], str] | str ], ) -> str: + """ + Construct a pass pipeline string for mlir-opt style tool. + + Args: + pipeline: A sequence of pass names and arguments. + - For the pass with no arguments/all default arguments, pass just the name as a string. + - For the pass with arguments, pass a tuple with the name and a dictionary of arguments. + - For the pass with a root op, pass a tuple with the name, a dictionary of arguments, and the root op name. + Arguments dict can be empty. + Returns: + A string representing the pass pipeline command line argument. + """ + def make_pass_arguments( - name: str, args: dict[str, Any], module_name: Optional[str] = None + name: str, args: dict[str, Any], root_op: Optional[str] = None ) -> str: ret = ( name @@ -221,8 +234,8 @@ def make_pass_arguments( + " ".join("=".join((key, str(value))) for (key, value) in args.items()) + "}" ) - if module_name: - ret = module_name + "(" + ret + ")" + if root_op: + ret = root_op + "(" + ret + ")" return ret return ( From 43453f230e168968406a06c03e8c5a61a99f6cfc Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 24 Nov 2025 15:49:14 +0100 Subject: [PATCH 63/77] unlocal imports Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/execution_engine/execution_engine.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/wave_lang/kernel/wave/execution_engine/execution_engine.py b/wave_lang/kernel/wave/execution_engine/execution_engine.py index abae20fa1..4d75fd718 100644 --- a/wave_lang/kernel/wave/execution_engine/execution_engine.py +++ b/wave_lang/kernel/wave/execution_engine/execution_engine.py @@ -15,6 +15,10 @@ import weakref from typing import Optional +import ctypes +import platform +from pathlib import Path + try: from .wave_execution_engine import ExecutionEngine, ExecutionEngineOptions except ImportError: @@ -44,9 +48,6 @@ def _load_library(lib_basename: str): Raises: RuntimeError: If the library cannot be found or loaded """ - import ctypes - import platform - from pathlib import Path # Find the library file lib_name = { From ab767258486dd87799eb3d054c9197649a5f8ecc Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 24 Nov 2025 16:08:51 +0100 Subject: [PATCH 64/77] cleanup, license headers Signed-off-by: Ivan Butygin --- .../kernel/wave/execution_engine/bindings.cpp | 3 +- .../wave/execution_engine/buffer_utils.cpp | 3 +- .../wave/execution_engine/buffer_utils.h | 3 +- .../execution_engine/execution_engine.cpp | 14 ++++---- .../wave/execution_engine/execution_engine.h | 3 +- .../wave/execution_engine/execution_engine.py | 36 +++++++------------ 6 files changed, 26 insertions(+), 36 deletions(-) diff --git a/wave_lang/kernel/wave/execution_engine/bindings.cpp b/wave_lang/kernel/wave/execution_engine/bindings.cpp index 46cdbf3cd..782d6ad81 100644 --- a/wave_lang/kernel/wave/execution_engine/bindings.cpp +++ b/wave_lang/kernel/wave/execution_engine/bindings.cpp @@ -1,4 +1,5 @@ -// Copyright 2025 The IREE Authors +// Copyright 2025 The Wave Authors +// // Licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception diff --git a/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp b/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp index ec48ba757..7b2c7e5f4 100644 --- a/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp +++ b/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp @@ -1,4 +1,5 @@ -// Copyright 2025 The IREE Authors +// Copyright 2025 The Wave Authors +// // Licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception diff --git a/wave_lang/kernel/wave/execution_engine/buffer_utils.h b/wave_lang/kernel/wave/execution_engine/buffer_utils.h index 740dd0562..f1b9eec6c 100644 --- a/wave_lang/kernel/wave/execution_engine/buffer_utils.h +++ b/wave_lang/kernel/wave/execution_engine/buffer_utils.h @@ -1,4 +1,5 @@ -// Copyright 2025 The IREE Authors +// Copyright 2025 The Wave Authors +// // Licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception diff --git a/wave_lang/kernel/wave/execution_engine/execution_engine.cpp b/wave_lang/kernel/wave/execution_engine/execution_engine.cpp index 697509810..a0b945091 100644 --- a/wave_lang/kernel/wave/execution_engine/execution_engine.cpp +++ b/wave_lang/kernel/wave/execution_engine/execution_engine.cpp @@ -1,4 +1,5 @@ -// Copyright 2025 The IREE Authors +// Copyright 2025 The Wave Authors +// // Licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception @@ -15,7 +16,11 @@ #include #include #include +#include +#include #include +#include +#include #include #include #include @@ -33,13 +38,6 @@ #include #include -#include -#include -#include -#include - -#include - #define DEBUG_TYPE "wave-execution-engine" static void initializeLLVMTarget() { diff --git a/wave_lang/kernel/wave/execution_engine/execution_engine.h b/wave_lang/kernel/wave/execution_engine/execution_engine.h index 66868f124..3382fdeef 100644 --- a/wave_lang/kernel/wave/execution_engine/execution_engine.h +++ b/wave_lang/kernel/wave/execution_engine/execution_engine.h @@ -1,4 +1,5 @@ -// Copyright 2025 The IREE Authors +// Copyright 2025 The Wave Authors +// // Licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception diff --git a/wave_lang/kernel/wave/execution_engine/execution_engine.py b/wave_lang/kernel/wave/execution_engine/execution_engine.py index 4d75fd718..7e67bc5dd 100644 --- a/wave_lang/kernel/wave/execution_engine/execution_engine.py +++ b/wave_lang/kernel/wave/execution_engine/execution_engine.py @@ -69,6 +69,10 @@ def _load_library(lib_basename: str): return ctypes.CDLL(str(lib_path), mode=ctypes.RTLD_GLOBAL) +def _get_symbol(lib: ctypes.CDLL, name: str) -> ctypes.c_void_p: + return ctypes.cast(getattr(lib, name), ctypes.c_void_p).value + + def _load_runtime_helpers(): """ Load the wave_runtime_helpers shared library and return symbol addresses. @@ -82,25 +86,16 @@ def _load_runtime_helpers(): Raises: RuntimeError: If the library cannot be found or loaded """ - import ctypes - lib = _load_library("wave_runtime_helpers") symbol_map = {} - wave_get_buffer_addr = ctypes.cast( - lib._mlir_ciface_wave_get_buffer, ctypes.c_void_p - ).value - symbol_map["_mlir_ciface_wave_get_buffer"] = wave_get_buffer_addr - - wave_get_int64_addr = ctypes.cast(lib.wave_get_int64, ctypes.c_void_p).value - symbol_map["wave_get_int64"] = wave_get_int64_addr - - wave_get_float64_addr = ctypes.cast(lib.wave_get_float64, ctypes.c_void_p).value - symbol_map["wave_get_float64"] = wave_get_float64_addr - - wave_get_dim_addr = ctypes.cast(lib.wave_get_dim, ctypes.c_void_p).value - symbol_map["wave_get_dim"] = wave_get_dim_addr + symbol_map["_mlir_ciface_wave_get_buffer"] = _get_symbol( + lib, "_mlir_ciface_wave_get_buffer" + ) + symbol_map["wave_get_int64"] = _get_symbol(lib, "wave_get_int64") + symbol_map["wave_get_float64"] = _get_symbol(lib, "wave_get_float64") + symbol_map["wave_get_dim"] = _get_symbol(lib, "wave_get_dim") return symbol_map @@ -117,21 +112,14 @@ def _load_hip_runtime(): Raises: RuntimeError: If the library cannot be found or loaded """ - import ctypes - lib = _load_library("wave_hip_runtime") - # Load HIP functions eagerly lib.load_functions() symbol_map = {} - # Register HIP runtime functions - wave_load_kernel_addr = ctypes.cast(lib.wave_load_kernel, ctypes.c_void_p).value - symbol_map["wave_load_kernel"] = wave_load_kernel_addr - - wave_launch_kernel_addr = ctypes.cast(lib.wave_launch_kernel, ctypes.c_void_p).value - symbol_map["wave_launch_kernel"] = wave_launch_kernel_addr + symbol_map["wave_load_kernel"] = _get_symbol(lib, "wave_load_kernel") + symbol_map["wave_launch_kernel"] = _get_symbol(lib, "wave_launch_kernel") return symbol_map From aada62f929a86012c97c211c3f9767630f02f474 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 24 Nov 2025 16:16:32 +0100 Subject: [PATCH 65/77] miltiline literals and comment Signed-off-by: Ivan Butygin --- .../kernel/wave/execution_engine/bindings.cpp | 108 ++++++++++-------- .../wave/execution_engine/buffer_utils.h | 3 +- 2 files changed, 64 insertions(+), 47 deletions(-) diff --git a/wave_lang/kernel/wave/execution_engine/bindings.cpp b/wave_lang/kernel/wave/execution_engine/bindings.cpp index 782d6ad81..b5d081afd 100644 --- a/wave_lang/kernel/wave/execution_engine/bindings.cpp +++ b/wave_lang/kernel/wave/execution_engine/bindings.cpp @@ -70,20 +70,22 @@ NB_MODULE(wave_execution_engine, m) { }; }, nb::arg("symbols"), - "Set symbol map from a dictionary of symbol names to addresses.\n\n" - "Args:\n" - " symbols: Dictionary mapping symbol names (str) to addresses " - "(int)\n\n" - "Example:\n" - " options.set_symbol_map({'my_function': 0x12345678})"); + R"(Set symbol map from a dictionary of symbol names to addresses. + +Args: + symbols: Dictionary mapping symbol names (str) to addresses (int) + +Example: + options.set_symbol_map({'my_function': 0x12345678}))"); // Bind ExecutionEngine class nb::class_(m, "ExecutionEngine", nb::is_weak_referenceable()) .def(nb::init(), nb::arg("options"), - "Create a new ExecutionEngine with the given options.\n\n" - "Args:\n" - " options: ExecutionEngineOptions to configure the engine") + R"(Create a new ExecutionEngine with the given options. + +Args: + options: ExecutionEngineOptions to configure the engine)") .def( "load_module", [](wave::ExecutionEngine &self, MlirModule cModule) { @@ -93,13 +95,16 @@ NB_MODULE(wave_execution_engine, m) { return reinterpret_cast(handle); }, nb::arg("module"), - "Compile and load an MLIR module into the execution engine.\n\n" - "Args:\n" - " module: MLIR module (MlirModule from MLIR C API)\n\n" - "Returns:\n" - " Module handle as integer\n\n" - "Raises:\n" - " RuntimeError: If compilation or loading fails") + R"(Compile and load an MLIR module into the execution engine. + +Args: + module: MLIR module (MlirModule from MLIR C API) + +Returns: + Module handle as integer + +Raises: + RuntimeError: If compilation or loading fails)") .def( "load_module_from_bytecode", [](wave::ExecutionEngine &self, nb::bytes bytecode) { @@ -110,13 +115,16 @@ NB_MODULE(wave_execution_engine, m) { return reinterpret_cast(handle); }, nb::arg("bytecode"), - "Deserialize MLIR bytecode and load it into the execution engine.\n\n" - "Args:\n" - " bytecode: MLIR module serialized as bytecode (bytes)\n\n" - "Returns:\n" - " Module handle as integer\n\n" - "Raises:\n" - " RuntimeError: If deserialization, compilation or loading fails") + R"(Deserialize MLIR bytecode and load it into the execution engine. + +Args: + bytecode: MLIR module serialized as bytecode (bytes) + +Returns: + Module handle as integer + +Raises: + RuntimeError: If deserialization, compilation or loading fails)") .def( "load_module_from_text", [](wave::ExecutionEngine &self, const std::string &mlirText) { @@ -126,22 +134,26 @@ NB_MODULE(wave_execution_engine, m) { return reinterpret_cast(handle); }, nb::arg("mlir_text"), - "Parse MLIR text and load it into the execution engine.\n\n" - "Args:\n" - " mlir_text: MLIR module as text string\n\n" - "Returns:\n" - " Module handle as integer\n\n" - "Raises:\n" - " RuntimeError: If parsing, compilation or loading fails") + R"(Parse MLIR text and load it into the execution engine. + +Args: + mlir_text: MLIR module as text string + +Returns: + Module handle as integer + +Raises: + RuntimeError: If parsing, compilation or loading fails)") .def( "release_module", [](wave::ExecutionEngine &self, uintptr_t handle) { self.releaseModule(reinterpret_cast(handle)); }, nb::arg("handle"), - "Release a loaded module from the execution engine.\n\n" - "Args:\n" - " handle: Module handle returned from load_module") + R"(Release a loaded module from the execution engine. + +Args: + handle: Module handle returned from load_module)") .def( "lookup", [](const wave::ExecutionEngine &self, uintptr_t handle, @@ -152,18 +164,22 @@ NB_MODULE(wave_execution_engine, m) { return reinterpret_cast(ptr); }, nb::arg("handle"), nb::arg("name"), - "Look up a function in a loaded module.\n\n" - "Args:\n" - " handle: Module handle returned from load_module\n" - " name: Name of the function to look up\n\n" - "Returns:\n" - " Function address as integer\n\n" - "Raises:\n" - " RuntimeError: If function lookup fails") + R"(Look up a function in a loaded module. + +Args: + handle: Module handle returned from load_module + name: Name of the function to look up + +Returns: + Function address as integer + +Raises: + RuntimeError: If function lookup fails)") .def("dump_to_object_file", &wave::ExecutionEngine::dumpToObjectFile, - nb::arg("filename"), - "Dump compiled object code to a file.\n\n" - "Note: Object cache must be enabled in ExecutionEngineOptions.\n\n" - "Args:\n" - " filename: Path to output file"); + nb::arg("filename"), R"(Dump compiled object code to a file. + +Note: Object cache must be enabled in ExecutionEngineOptions. + +Args: + filename: Path to output file)"); } diff --git a/wave_lang/kernel/wave/execution_engine/buffer_utils.h b/wave_lang/kernel/wave/execution_engine/buffer_utils.h index f1b9eec6c..da5d2f4d2 100644 --- a/wave_lang/kernel/wave/execution_engine/buffer_utils.h +++ b/wave_lang/kernel/wave/execution_engine/buffer_utils.h @@ -19,7 +19,8 @@ template struct StridedMemRefType { int64_t strides[N]; // Stride of each dimension in elements }; -/// Rank-1 memref descriptor for memref +/// Rank-1 memref descriptor for memref, we need to pass this to +// memref.view op which only accepts 1D i8 memrefs. using MemRef1Di8 = StridedMemRefType; extern "C" { From 15aa357bdd61450c97dee8699b24f31ab9ce38b7 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 24 Nov 2025 16:50:16 +0100 Subject: [PATCH 66/77] better wave-opt error handling Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/water.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wave_lang/kernel/wave/water.py b/wave_lang/kernel/wave/water.py index 5558a5db1..3b291a273 100644 --- a/wave_lang/kernel/wave/water.py +++ b/wave_lang/kernel/wave/water.py @@ -438,8 +438,8 @@ def water_lowering_pipeline(module: Module, options: WaveCompileOptions) -> Modu text=True, ) except subprocess.CalledProcessError as e: - print(e.stderr) - raise e + error_msg = f"Subprocess failed with return code {e.returncode}.\nStderr output:\n{e.stderr}" + raise RuntimeError(error_msg) from e with module.context: return Module.parse(result) From 3cbe6163641ce69f178a7cdde3ecb46382dd6981 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 24 Nov 2025 16:52:59 +0100 Subject: [PATCH 67/77] typos Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/execution_engine/execution_engine.cpp | 2 +- wave_lang/kernel/wave/execution_engine/execution_engine.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/wave_lang/kernel/wave/execution_engine/execution_engine.cpp b/wave_lang/kernel/wave/execution_engine/execution_engine.cpp index a0b945091..f6c089b69 100644 --- a/wave_lang/kernel/wave/execution_engine/execution_engine.cpp +++ b/wave_lang/kernel/wave/execution_engine/execution_engine.cpp @@ -256,7 +256,7 @@ wave::ExecutionEngine::loadModule(mlir::ModuleOp m) { if (!llvmModule) return makeStringError("could not convert to LLVM IR"); - // Add a ThreadSafemodule to the engine and return. + // Add a ThreadSafeModule to the engine and return. llvm::orc::ThreadSafeModule tsm(std::move(llvmModule), std::move(ctx)); if (transformer) cantFail(tsm.withModuleDo( diff --git a/wave_lang/kernel/wave/execution_engine/execution_engine.py b/wave_lang/kernel/wave/execution_engine/execution_engine.py index 7e67bc5dd..d51c7d2fc 100644 --- a/wave_lang/kernel/wave/execution_engine/execution_engine.py +++ b/wave_lang/kernel/wave/execution_engine/execution_engine.py @@ -64,7 +64,7 @@ def _load_library(lib_basename: str): lib_path = module_dir / lib_name if not lib_path.exists(): - raise RuntimeError(f"{lib_basename} library not found at {lib_path}. ") + raise RuntimeError(f"{lib_basename} library not found at {lib_path}.") return ctypes.CDLL(str(lib_path), mode=ctypes.RTLD_GLOBAL) From 435105ffade950f2d572a83395769686b72300b8 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 24 Nov 2025 17:08:15 +0100 Subject: [PATCH 68/77] simplify wave_get_buffer Signed-off-by: Ivan Butygin --- .../wave/execution_engine/buffer_utils.cpp | 23 ++----------------- 1 file changed, 2 insertions(+), 21 deletions(-) diff --git a/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp b/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp index 7b2c7e5f4..b6f427bbd 100644 --- a/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp +++ b/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp @@ -50,31 +50,12 @@ extern "C" void _mlir_ciface_wave_get_buffer(MemRef1Di8 *ret, "wave_get_buffer: data_ptr() did not return a valid pointer"); } - // Get tensor.numel() method and call it - PyObjectPtr numel_method(PyObject_GetAttrString(obj_ptr, "numel")); - if (!numel_method) { - PyErr_Clear(); - throw std::runtime_error( - "wave_get_buffer: Object does not have 'numel' attribute"); - } - - PyObjectPtr numel_result(PyObject_CallNoArgs(numel_method.get())); - if (!numel_result) { - PyErr_Clear(); - throw std::runtime_error("wave_get_buffer: Failed to call numel()"); - } - - int64_t numel = PyLong_AsLongLong(numel_result.get()); - if (PyErr_Occurred()) { - PyErr_Clear(); - throw std::runtime_error("wave_get_buffer: numel() returned invalid value"); - } - // Fill in the memref descriptor ret->basePtr = static_cast(data_ptr); ret->data = static_cast(data_ptr); ret->offset = 0; - ret->sizes[0] = numel; + // Actual size doesn't matter we will cast it to 0D memref immediately. + ret->sizes[0] = -1; ret->strides[0] = 1; } From 24710f470a90bf43c93f364a24de1bdd1d81f02f Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 26 Nov 2025 15:08:22 +0100 Subject: [PATCH 69/77] fix search path Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/water.py | 24 +++++++----------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/wave_lang/kernel/wave/water.py b/wave_lang/kernel/wave/water.py index 3b291a273..9a7fd6bb0 100644 --- a/wave_lang/kernel/wave/water.py +++ b/wave_lang/kernel/wave/water.py @@ -13,7 +13,6 @@ import sys import math from typing import Any, Sequence, Optional -import importlib from functools import lru_cache from wave_lang.kernel.wave.compile_options import WaveCompileOptions @@ -175,7 +174,7 @@ def replace_ops_and_collect_subspans(op: Operation) -> WalkResult: return local_module.get_asm(binary=False, print_generic_op_form=True) -def find_local_water_binary_path(name: str) -> Optional[str]: +def get_binary(name: str) -> Optional[str]: this_path = Path(__file__).parent tool_path = this_path / "water_mlir" / "bin" / name if not tool_path.is_file() or not os.access(tool_path, os.X_OK): @@ -186,25 +185,16 @@ def find_local_water_binary_path(name: str) -> Optional[str]: @lru_cache def is_water_available() -> bool: """Returns True if the water_mlir package is available.""" - if (Path(__file__).parent / "water_mlir").exists(): - return True - - return importlib.util.find_spec("water_mlir") is not None + return (Path(__file__).parent / "water_mlir" / "water_mlir").exists() @lru_cache -def get_water_binary_path() -> str: - path = find_local_water_binary_path("water-opt") - if path: - return path +def get_water_opt() -> str: + path = get_binary("water-opt") + if path is None: + raise RuntimeError("water-opt binary not found") - try: - from water_mlir import binaries as water_bin - except ImportError as err: - raise RuntimeError( - "optional water_mlir module not installed but its use is requested" - ) from err - return water_bin.find_binary("water-opt") + return path def make_linear_pass_pipeline( From 09a304cbd5f34f8c85a2ca604782e4ac858e2311 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 26 Nov 2025 15:11:15 +0100 Subject: [PATCH 70/77] fix func call Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/water.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wave_lang/kernel/wave/water.py b/wave_lang/kernel/wave/water.py index 9a7fd6bb0..98de1a660 100644 --- a/wave_lang/kernel/wave/water.py +++ b/wave_lang/kernel/wave/water.py @@ -239,7 +239,7 @@ def make_pass_arguments( def water_leak_in_bounds_check(module: Module, override_ir: str = ""): - binary = get_water_binary_path() + binary = get_water_opt() generic_mlir = _deiree(module) if override_ir == "" else override_ir pipeline = [ ( @@ -394,7 +394,7 @@ def diagnostic_from_json( def water_lowering_pipeline(module: Module, options: WaveCompileOptions) -> Module: - binary = get_water_binary_path() + binary = get_water_opt() mlir_asm = module.operation.get_asm() target_chip = options.target opt_pass = "composite-fixed-point-pass", { From 2dcd0de8a6adf5813ea44e524421ee9916ab8859 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 26 Nov 2025 15:12:32 +0100 Subject: [PATCH 71/77] comment Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/execution_engine/buffer_utils.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wave_lang/kernel/wave/execution_engine/buffer_utils.h b/wave_lang/kernel/wave/execution_engine/buffer_utils.h index da5d2f4d2..fbe97ed6a 100644 --- a/wave_lang/kernel/wave/execution_engine/buffer_utils.h +++ b/wave_lang/kernel/wave/execution_engine/buffer_utils.h @@ -20,7 +20,7 @@ template struct StridedMemRefType { }; /// Rank-1 memref descriptor for memref, we need to pass this to -// memref.view op which only accepts 1D i8 memrefs. +/// memref.view op which only accepts 1D i8 memrefs. using MemRef1Di8 = StridedMemRefType; extern "C" { From 3564e798cf865d8970434c43800f8f1eb968a9be Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 26 Nov 2025 15:19:31 +0100 Subject: [PATCH 72/77] fix func Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/water.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wave_lang/kernel/wave/water.py b/wave_lang/kernel/wave/water.py index 98de1a660..b87429047 100644 --- a/wave_lang/kernel/wave/water.py +++ b/wave_lang/kernel/wave/water.py @@ -174,7 +174,7 @@ def replace_ops_and_collect_subspans(op: Operation) -> WalkResult: return local_module.get_asm(binary=False, print_generic_op_form=True) -def get_binary(name: str) -> Optional[str]: +def find_binary(name: str) -> Optional[str]: this_path = Path(__file__).parent tool_path = this_path / "water_mlir" / "bin" / name if not tool_path.is_file() or not os.access(tool_path, os.X_OK): @@ -190,7 +190,7 @@ def is_water_available() -> bool: @lru_cache def get_water_opt() -> str: - path = get_binary("water-opt") + path = find_binary("water-opt") if path is None: raise RuntimeError("water-opt binary not found") From 85983d35d8ef421f7e3ccdba25c48577868823ed Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 27 Nov 2025 15:47:47 +0100 Subject: [PATCH 73/77] fix runtime funcs Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/execution_engine/buffer_utils.cpp | 7 ++++--- wave_lang/kernel/wave/execution_engine/buffer_utils.h | 6 +++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp b/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp index b6f427bbd..7b4b05d7a 100644 --- a/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp +++ b/wave_lang/kernel/wave/execution_engine/buffer_utils.cpp @@ -59,7 +59,7 @@ extern "C" void _mlir_ciface_wave_get_buffer(MemRef1Di8 *ret, ret->strides[0] = 1; } -extern "C" int64_t wave_get_int64(PyObject *obj_ptr) { +extern "C" int64_t _mlir_ciface_wave_get_int64(PyObject *obj_ptr) { GILState gil_state; int64_t value = PyLong_AsLongLong(obj_ptr); @@ -71,7 +71,7 @@ extern "C" int64_t wave_get_int64(PyObject *obj_ptr) { return value; } -extern "C" double wave_get_float64(PyObject *obj_ptr) { +extern "C" double _mlir_ciface_wave_get_float64(PyObject *obj_ptr) { GILState gil_state; double value = PyFloat_AsDouble(obj_ptr); @@ -83,7 +83,8 @@ extern "C" double wave_get_float64(PyObject *obj_ptr) { return value; } -extern "C" int64_t wave_get_dim(PyObject *obj_ptr, int32_t dim_idx) { +extern "C" int64_t _mlir_ciface_wave_get_dim(PyObject *obj_ptr, + int32_t dim_idx) { GILState gil_state; // Get tensor.size() method diff --git a/wave_lang/kernel/wave/execution_engine/buffer_utils.h b/wave_lang/kernel/wave/execution_engine/buffer_utils.h index fbe97ed6a..fb9cc280a 100644 --- a/wave_lang/kernel/wave/execution_engine/buffer_utils.h +++ b/wave_lang/kernel/wave/execution_engine/buffer_utils.h @@ -41,11 +41,11 @@ void _mlir_ciface_wave_get_buffer(MemRef1Di8 *ret, PyObject *obj); /// Extract an int64_t value from a PyObject. /// Throws std::runtime_error if conversion fails. -int64_t wave_get_int64(PyObject *obj); +int64_t _mlir_ciface_wave_get_int64(PyObject *obj); /// Extract a double value from a PyObject. /// Throws std::runtime_error if conversion fails. -double wave_get_float64(PyObject *obj); +double _mlir_ciface_wave_get_float64(PyObject *obj); /// Extract the size of a specific dimension from a PyObject (PyTorch tensor). /// @@ -59,5 +59,5 @@ double wave_get_float64(PyObject *obj); /// Throws: /// std::runtime_error if the object doesn't have a size() method or /// if the dimension index is invalid -int64_t wave_get_dim(PyObject *obj, int32_t dim_idx); +int64_t _mlir_ciface_wave_get_dim(PyObject *obj, int32_t dim_idx); } From cb29e98852cb9eb4fbc5d981e9c9208ba899936c Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 27 Nov 2025 15:49:13 +0100 Subject: [PATCH 74/77] enable shared libs Signed-off-by: Ivan Butygin --- .github/workflows/ci-gpu.yaml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/.github/workflows/ci-gpu.yaml b/.github/workflows/ci-gpu.yaml index 805e32a77..8189a4959 100644 --- a/.github/workflows/ci-gpu.yaml +++ b/.github/workflows/ci-gpu.yaml @@ -237,10 +237,7 @@ jobs: fail-fast: false matrix: name: [linux-mi325-1gpu-ossci-iree-org] - shared_libs: ["OFF"] - # TODO: we are linking water shared libs with static LLVM libraries, - # which doesn't really work if multiple of them linked with LLVM proper. - # shared_libs: ["ON", "OFF"] + shared_libs: ["ON", "OFF"] runs-on: ${{ matrix.name }} timeout-minutes: 60 needs: build_llvm_linux From ce34549641f2024e2b6f9088650ed155a44a7894 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 27 Nov 2025 15:50:14 +0100 Subject: [PATCH 75/77] remove duplicated test Signed-off-by: Ivan Butygin --- lit_tests/kernel/wave/water_lowering.py | 111 ------------------------ 1 file changed, 111 deletions(-) delete mode 100644 lit_tests/kernel/wave/water_lowering.py diff --git a/lit_tests/kernel/wave/water_lowering.py b/lit_tests/kernel/wave/water_lowering.py deleted file mode 100644 index 057a7d350..000000000 --- a/lit_tests/kernel/wave/water_lowering.py +++ /dev/null @@ -1,111 +0,0 @@ -# REQUIRES: water -# RUN: python %s | FileCheck %s - -import wave_lang.kernel.lang as tkl -import wave_lang.kernel.wave as tkw -from wave_lang.kernel.lang.global_symbols import * -from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile -from wave_lang.kernel.wave.utils.general_utils import ( - run_test, -) -from wave_lang.support.location_config import ( - LocationCaptureConfig, - LocationCaptureLevel, -) - -M = tkl.sym.M -N = tkl.sym.N -K = tkl.sym.K -B = tkl.sym.B -BLOCK_M = tkl.sym.BLOCK_M -BLOCK_N = tkl.sym.BLOCK_N -BLOCK_K = tkl.sym.BLOCK_K -BLOCK_B = tkl.sym.BLOCK_B -LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD -STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD -ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE -ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0 - - -def get_wave_compile_options( - canonicalize: bool = False, - dynamic_symbols=[], - additional_symbols={}, - location_capture_config=LocationCaptureConfig( - level=LocationCaptureLevel.FILE_LINE_COL - ), - drop_debug_info_before_mlir=True, -): - bindings = { - M: 16, - N: 16, - K: 16, - BLOCK_M: 16, - BLOCK_N: 16, - BLOCK_K: 16, - ADDRESS_SPACE: tkl.AddressSpace.SHARED_MEMORY.value, - } - bindings.update(additional_symbols) - - # Remove dynamic symbols from the bindings. - for sym in dynamic_symbols: - if sym in bindings: - del bindings[sym] - - return WaveCompileOptions( - subs=bindings, - canonicalize=canonicalize, - dynamic_symbols=dynamic_symbols, - compile_to_mlir=True, - location_capture_config=location_capture_config, - drop_debug_info_before_mlir=drop_debug_info_before_mlir, - use_water_pipeline=True, - ) - - -@run_test -def test_read_write(): - constraints: list[tkw.Constraint] = [ - tkw.HardwareConstraint(threads_per_wave=64, vector_shapes={M: 16, N: 16}) - ] - constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] - constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] - constraints += [tkw.WaveConstraint(M, BLOCK_M)] - constraints += [tkw.WaveConstraint(N, BLOCK_N)] - - @tkw.wave(constraints) - def read_write( - a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], - b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], - ): - res = tkw.read(a) - tkw.write(res, b) - - read_write = wave_compile(get_wave_compile_options(canonicalize=True), read_write) - print(read_write.asm) - - # CHECK-LABEL: test_read_write - # CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 - (s0 floordiv 64) * 48)> - # CHECK: gpu.module @gpu_module - # CHECK: gpu.func @read_write - # CHECK-SAME: (%[[D0:.*]]: memref, %[[D1:.*]]: memref) - # CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index - # CHECK-DAG: %[[thread_id_x:.*]] = gpu.thread_id x - # CHECK: %[[S0:.*]] = memref.reinterpret_cast %[[D0]] to offset: [0], sizes: [16, 16], strides: [16, 1] : memref to memref<16x16xf16, strided<[16, 1]>> - # CHECK: %[[S1:.*]] = memref.reinterpret_cast %[[D1]] to offset: [0], sizes: [16, 16], strides: [16, 1] : memref to memref<16x16xf16, strided<[16, 1]>> - # CHECK: %[[I0:.*]] = affine.apply #[[MAP0]]()[%[[thread_id_x]]] - # CHECK: %[[V:.*]] = vector.load %[[S0]][%[[I0]], %[[C0]]] : memref<16x16xf16, strided<[16, 1]>>, vector<16xf16> - # CHECK: vector.store %[[V]], %[[S1]][%[[I0]], %[[C0]]] : memref<16x16xf16, strided<[16, 1]>>, vector<16xf16> - # CHECK: return - - # CHECK-LABEL: func.func @isolated_benchmark - # CHECK-SAME: (%[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr, %[[ARG2:.*]]: !llvm.ptr) - # CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index - # CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index - # CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index - # CHECK: %[[BUF1:.*]] = call @wave_get_buffer(%[[ARG1]]) : (!llvm.ptr) -> memref - # CHECK: %[[VIEW1:.*]] = memref.view %[[BUF1]][%[[C0]]][] : memref to memref - # CHECK: %[[BUF2:.*]] = call @wave_get_buffer(%[[ARG2]]) : (!llvm.ptr) -> memref - # CHECK: %[[VIEW2:.*]] = memref.view %[[BUF2]][%[[C0]]][] : memref to memref - # CHECK: gpu.launch_func @gpu_module::@read_write blocks in (%[[C1]], %[[C1]], %[[C1]]) threads in (%[[C64]], %[[C1]], %[[C1]]) args(%[[VIEW1]] : memref, %[[VIEW2]] : memref) - # CHECK: return From e3ba1a5475dc8aec97be800bc208e0e9a590cbcd Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 27 Nov 2025 15:52:12 +0100 Subject: [PATCH 76/77] cleanup Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/codegen/emitter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/wave_lang/kernel/wave/codegen/emitter.py b/wave_lang/kernel/wave/codegen/emitter.py index 306aa8932..e236fcdd0 100644 --- a/wave_lang/kernel/wave/codegen/emitter.py +++ b/wave_lang/kernel/wave/codegen/emitter.py @@ -20,7 +20,6 @@ from wave_lang.kernel.ops.wave_ops import get_custom from wave_lang.kernel.lang import Memory from wave_lang.kernel.lang.kernel_buffer import KernelBuffer -from wave_lang.kernel.compiler.kernel_codegen import BindingType from wave_lang.kernel.compiler.utils import strides_from_symbolic_shape from wave_lang.kernel.lang.global_symbols import * from wave_lang.support.logging import get_logger From 15f9bd8a3802d842aa91f32d59d39d7802f09a8b Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 27 Nov 2025 16:04:25 +0100 Subject: [PATCH 77/77] refac includes Signed-off-by: Ivan Butygin --- .../wave/execution_engine/CMakeLists.txt | 3 +++ .../execution_engine/wave_hip_runtime.cpp | 25 ++----------------- wave_lang/kernel/wave/runtime/hip_types.h | 1 + 3 files changed, 6 insertions(+), 23 deletions(-) diff --git a/wave_lang/kernel/wave/execution_engine/CMakeLists.txt b/wave_lang/kernel/wave/execution_engine/CMakeLists.txt index 39c86629f..76bedeed2 100644 --- a/wave_lang/kernel/wave/execution_engine/CMakeLists.txt +++ b/wave_lang/kernel/wave/execution_engine/CMakeLists.txt @@ -49,6 +49,9 @@ add_library(wave_hip_runtime SHARED wave_hip_runtime.cpp ) +# Add other runtime dir so we can include hip_types.h +target_include_directories(wave_hip_runtime PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/../runtime") + nanobind_add_module(wave_execution_engine NB_STATIC execution_engine.cpp bindings.cpp diff --git a/wave_lang/kernel/wave/execution_engine/wave_hip_runtime.cpp b/wave_lang/kernel/wave/execution_engine/wave_hip_runtime.cpp index 4a2c68493..dd3eeb77a 100644 --- a/wave_lang/kernel/wave/execution_engine/wave_hip_runtime.cpp +++ b/wave_lang/kernel/wave/execution_engine/wave_hip_runtime.cpp @@ -9,6 +9,8 @@ #include #include +#include "hip_types.h" + #if defined(__linux__) #include // dlopen, dlsym, dlerror using module_handle_t = void *; @@ -20,29 +22,6 @@ using module_handle_t = HMODULE; #error "Unsupported platform" #endif -// HIP constants and types -#define HIP_LAUNCH_PARAM_BUFFER_POINTER ((void *)0x01) -#define HIP_LAUNCH_PARAM_BUFFER_SIZE ((void *)0x02) -#define HIP_LAUNCH_PARAM_END ((void *)0x03) - -using hipError_t = int; -using hipStream_t = void *; -using hipFunction_t = void *; -using hipModule_t = void *; - -using hipModuleLaunchKernel_t = hipError_t (*)(hipFunction_t, unsigned int, - unsigned int, unsigned int, - unsigned int, unsigned int, - unsigned int, unsigned int, - hipStream_t, void **, void **); - -using hipGetErrorName_t = const char *(*)(hipError_t); -using hipGetErrorString_t = const char *(*)(hipError_t); -using hipModuleUnload_t = hipError_t (*)(hipModule_t); -using hipModuleLoadData_t = hipError_t (*)(hipModule_t *, const void *); -using hipModuleGetFunction_t = hipError_t (*)(hipFunction_t *, hipModule_t, - const char *); - // Global function pointers static hipModuleLaunchKernel_t hipModuleLaunchKernel = nullptr; static hipGetErrorName_t hipGetErrorName = nullptr; diff --git a/wave_lang/kernel/wave/runtime/hip_types.h b/wave_lang/kernel/wave/runtime/hip_types.h index a2e2b297d..633be7852 100644 --- a/wave_lang/kernel/wave/runtime/hip_types.h +++ b/wave_lang/kernel/wave/runtime/hip_types.h @@ -92,5 +92,6 @@ using hipGetErrorName_t = const char *(*)(hipError_t); using hipGetErrorString_t = const char *(*)(hipError_t); using hipModuleUnload_t = hipError_t (*)(hipModule_t); using hipModuleLoad_t = hipError_t (*)(hipModule_t *, const char *); +using hipModuleLoadData_t = hipError_t (*)(hipModule_t *, const void *); using hipModuleGetFunction_t = hipError_t (*)(hipFunction_t *, hipModule_t, const char *);