Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
77 commits
Select commit Hold shift + click to select a range
2cf9724
register everything
Hardcode84 Oct 22, 2025
a63e139
initial pipeline
Hardcode84 Oct 22, 2025
fc8ed5e
register llvm translation
Hardcode84 Oct 22, 2025
c7dd516
more pipeline
Hardcode84 Oct 22, 2025
3f594a8
return
Hardcode84 Oct 23, 2025
9c6369c
gpu-to-gpu-runtime
Hardcode84 Oct 23, 2025
f7433f5
WIP: Add LLVM ExecutionEngine stub
Hardcode84 Nov 9, 2025
1cf0010
ExecutionEngine WIP
Hardcode84 Nov 9, 2025
ec2b19d
Fix LLVM API compatibility in ExecutionEngine
Hardcode84 Nov 9, 2025
7e4c6c8
new wrapper
Hardcode84 Nov 10, 2025
8b34f65
execution engine
Hardcode84 Nov 10, 2025
8038c35
init llvm target
Hardcode84 Nov 11, 2025
2a4c7fd
execution_engine python wrapper
Hardcode84 Nov 11, 2025
1bfa8d8
load module from bytecode
Hardcode84 Nov 11, 2025
b3b09f5
bytecode loading
Hardcode84 Nov 11, 2025
42839d1
load from text
Hardcode84 Nov 11, 2025
254cb7a
buffer utils
Hardcode84 Nov 15, 2025
f21d597
host wrapper name
Hardcode84 Nov 15, 2025
413aa85
move ctypes
Hardcode84 Nov 15, 2025
6173205
use current stream
Hardcode84 Nov 15, 2025
dd23422
runtime
Hardcode84 Nov 16, 2025
0da8d14
fix import
Hardcode84 Nov 16, 2025
3cc0aed
expose symbolMap
Hardcode84 Nov 16, 2025
046b578
cleanup
Hardcode84 Nov 16, 2025
be6f25e
hip_runtime
Hardcode84 Nov 16, 2025
f7f458c
torch utils
Hardcode84 Nov 16, 2025
16fbe83
buffer_utils WIP
Hardcode84 Nov 16, 2025
a8f9eeb
python runtime
Hardcode84 Nov 16, 2025
9c88480
error handling
Hardcode84 Nov 16, 2025
0051194
module loading
Hardcode84 Nov 16, 2025
3b40efd
load hip funcs explicitly
Hardcode84 Nov 16, 2025
cfe5a09
scalar args
Hardcode84 Nov 16, 2025
e6218cd
fix test_dynamic_copy
Hardcode84 Nov 17, 2025
0735065
get_dim impl
Hardcode84 Nov 17, 2025
ae60dbb
erase original func
Hardcode84 Nov 17, 2025
41d2856
cleanup cmake
Hardcode84 Nov 17, 2025
7855c4c
more comments
Hardcode84 Nov 17, 2025
7d7af5a
install
Hardcode84 Nov 22, 2025
e466b9d
local water
Hardcode84 Nov 22, 2025
225f4bf
tests
Hardcode84 Nov 23, 2025
8db1648
refa coptions
Hardcode84 Nov 23, 2025
35e9ed6
simplify execution_engine
Hardcode84 Nov 23, 2025
d8d3fac
opt_pass
Hardcode84 Nov 23, 2025
342601e
print IR after all
Hardcode84 Nov 23, 2025
d8dbefc
rename class
Hardcode84 Nov 23, 2025
4f2bce4
test
Hardcode84 Nov 23, 2025
29c68c6
fix test
Hardcode84 Nov 23, 2025
460206e
register dialect
Hardcode84 Nov 23, 2025
18ace68
lib
Hardcode84 Nov 23, 2025
b8f1ba1
skip shared libs
Hardcode84 Nov 23, 2025
57b6a2f
fix tests
Hardcode84 Nov 23, 2025
2b77f89
lit test
Hardcode84 Nov 23, 2025
d714cf7
REQUIRES
Hardcode84 Nov 23, 2025
c3edecc
fix typo
Hardcode84 Nov 23, 2025
c8135b0
typos and cleanup
Hardcode84 Nov 23, 2025
debc49e
runtime pass lit
Hardcode84 Nov 24, 2025
3fcf325
typos
Hardcode84 Nov 24, 2025
67bf7ed
fix python
Hardcode84 Nov 24, 2025
ab44800
if(auto error = ...)
Hardcode84 Nov 24, 2025
3b414ab
require_water_and_ee
Hardcode84 Nov 24, 2025
5a491c3
cleanup
Hardcode84 Nov 24, 2025
4e6b923
make_linear_pass_pipeline doc
Hardcode84 Nov 24, 2025
43453f2
unlocal imports
Hardcode84 Nov 24, 2025
ab76725
cleanup, license headers
Hardcode84 Nov 24, 2025
aada62f
miltiline literals and comment
Hardcode84 Nov 24, 2025
15aa357
better wave-opt error handling
Hardcode84 Nov 24, 2025
3cbe616
typos
Hardcode84 Nov 24, 2025
435105f
simplify wave_get_buffer
Hardcode84 Nov 24, 2025
24710f4
fix search path
Hardcode84 Nov 26, 2025
09a304c
fix func call
Hardcode84 Nov 26, 2025
2dcd0de
comment
Hardcode84 Nov 26, 2025
3564e79
fix func
Hardcode84 Nov 26, 2025
85983d3
fix runtime funcs
Hardcode84 Nov 27, 2025
cb29e98
enable shared libs
Hardcode84 Nov 27, 2025
ce34549
remove duplicated test
Hardcode84 Nov 27, 2025
e3ba1a5
cleanup
Hardcode84 Nov 27, 2025
15f9bd8
refac includes
Hardcode84 Nov 27, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,12 @@ def initialize_options(self):

if BUILD_WATER:
ext_modules += [
CMakeExtension(
"wave_execution_engine",
"wave_lang/kernel/wave/execution_engine",
install_dir="wave_lang/kernel/wave/execution_engine",
need_llvm=True,
),
Comment on lines +285 to +290
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So Wave EE is under BUILD_WATER flag. Are we moving towards rebranding the MLIR layer of Wave Water? If so, let's update the documentation.

CMakeExtension(
"water",
"water",
Expand Down
15 changes: 14 additions & 1 deletion tests/kernel/wave/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,17 @@ def param_bool(name, shortname=None, values=None):
shortname = shortname or name
values = values or [False, True]
ids = [f"{shortname}" if v else f"no_{shortname}" for v in values]
return pytest.mark.parametrize(name, [pytest.param(v) for v in values], ids=ids)
return pytest.mark.parametrize(name, values, ids=ids)


def _is_water_and_ee_available() -> bool:
from wave_lang.kernel.wave.water import is_water_available
from wave_lang.kernel.wave.execution_engine import is_execution_engine_available

return is_water_available() and is_execution_engine_available()


require_water_and_ee = pytest.mark.skipif(
not _is_water_and_ee_available(),
reason="Water or execution engine are not available.",
)
177 changes: 177 additions & 0 deletions tests/kernel/wave/test_execution_engine_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
# Copyright 2025 The IREE Authors
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""
Tests for the ExecutionEngine wrapper with weak reference caching.
"""

import gc
import os
import pytest
import weakref

from wave_lang.kernel.wave.execution_engine import (
clear_engine_cache,
get_execution_engine,
is_engine_cached,
is_execution_engine_available,
)

pytestmark = pytest.mark.skipif(
not is_execution_engine_available(),
reason="ExecutionEngine not available (C++ extension not built)",
)


ENV_VARS = [
"WAVE_ENABLE_OBJECT_CACHE",
"WAVE_ENABLE_GDB_LISTENER",
"WAVE_ENABLE_PERF_LISTENER",
]


@pytest.fixture
def clean_env():
"""Fixture to clean cache and environment variables before/after each test."""
clear_engine_cache()

# Save original env vars
orig_env = {}
for key in ENV_VARS:
orig_env[key] = os.environ.get(key)
if key in os.environ:
del os.environ[key]

yield

# Restore original env vars
for key, value in orig_env.items():
if value is not None:
os.environ[key] = value
elif key in os.environ:
del os.environ[key]

clear_engine_cache()
gc.collect()


def test_basic_creation(clean_env):
"""Test basic creation of execution engine."""
engine = get_execution_engine()
assert engine is not None
assert is_engine_cached()


def test_cache_returns_same_instance(clean_env):
"""Test that cache returns the same instance."""
engine1 = get_execution_engine()
engine2 = get_execution_engine()

# Should be the exact same object
assert engine1 is engine2


def test_weak_reference_cleanup(clean_env):
"""Test that engine is cleaned up when references are released."""
# Create engine and get weak reference
engine = get_execution_engine()
weak_ref = weakref.ref(engine)

# Verify it's cached
assert is_engine_cached()

# Delete strong reference
del engine
gc.collect()

# Weak reference should be dead
assert weak_ref() is None

# Cache should report no engine
assert not is_engine_cached()


def test_clear_cache(clean_env):
"""Test that clear_cache removes cached engine."""
engine = get_execution_engine()
assert is_engine_cached()

# Clear cache
clear_engine_cache()

# Cache should be empty
assert not is_engine_cached()

# Engine should still be alive (we have strong reference)
assert engine is not None


def test_env_var_object_cache(clean_env):
"""Test WAVE_ENABLE_OBJECT_CACHE environment variable."""
os.environ["WAVE_ENABLE_OBJECT_CACHE"] = "1"

engine = get_execution_engine()
assert engine is not None
# We can't directly test if object cache is enabled without
# accessing internal state, but we verify the engine was created


def test_env_var_gdb_listener(clean_env):
"""Test WAVE_ENABLE_GDB_LISTENER environment variable."""
os.environ["WAVE_ENABLE_GDB_LISTENER"] = "1"

engine = get_execution_engine()
assert engine is not None


def test_env_var_perf_listener(clean_env):
"""Test WAVE_ENABLE_PERF_LISTENER environment variable."""
os.environ["WAVE_ENABLE_PERF_LISTENER"] = "1"

engine = get_execution_engine()
assert engine is not None


def test_all_env_vars(clean_env):
"""Test all environment variables together."""
os.environ["WAVE_ENABLE_OBJECT_CACHE"] = "1"
os.environ["WAVE_ENABLE_GDB_LISTENER"] = "1"
os.environ["WAVE_ENABLE_PERF_LISTENER"] = "1"

engine = get_execution_engine()
assert engine is not None
assert is_engine_cached()


def test_multiple_references(clean_env):
"""Test that multiple references to the same engine work correctly."""
engine1 = get_execution_engine()
engine2 = get_execution_engine()
engine3 = get_execution_engine()

# All should be the same object
assert engine1 is engine2
assert engine2 is engine3


def test_engine_interface(clean_env):
"""Test that the engine provides the correct interface."""
engine = get_execution_engine()

# Check that required methods exist
assert hasattr(engine, "load_module")
assert hasattr(engine, "release_module")
assert hasattr(engine, "lookup")
assert hasattr(engine, "dump_to_object_file")

# Check that methods are callable
assert callable(engine.load_module)
assert callable(engine.release_module)
assert callable(engine.lookup)
assert callable(engine.dump_to_object_file)


if __name__ == "__main__":
pytest.main([__file__, "-v"])
21 changes: 17 additions & 4 deletions tests/kernel/wave/wave_e2e_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,11 @@
param_bool,
perf_test,
require_cdna3,
require_e2e,
require_cdna_2_or_3_or_4,
require_cdna4,
require_cdna_2_or_3_or_4,
require_e2e,
require_rdna4,
require_water_and_ee,
)
from .common.shapes import get_test_shapes as get_common_test_shape

Expand Down Expand Up @@ -138,8 +139,13 @@ def test(
@require_e2e
@pytest.mark.parametrize("shape", get_test_shapes("test_copy"))
@param_bool("use_buffer_ops", "buf_ops")
@param_bool(
"use_water_pipeline",
"water",
values=[False, pytest.param(True, marks=require_water_and_ee)],
)
@check_leaks
def test_copy(shape, use_buffer_ops, run_bench):
def test_copy(shape, use_buffer_ops, run_bench, use_water_pipeline):
M = tkl.sym.M
N = tkl.sym.N
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
Expand Down Expand Up @@ -183,6 +189,7 @@ def test(
canonicalize=True,
run_bench=run_bench,
use_buffer_ops=use_buffer_ops,
use_water_pipeline=use_water_pipeline,
)
options = set_default_run_config(options)
test = wave_compile(options, test)
Expand All @@ -194,7 +201,12 @@ def test(
@require_e2e
@pytest.mark.parametrize("shape", get_test_shapes("test_copy"))
@param_bool("use_buffer_ops", "buf_ops")
def test_dynamic_copy(shape, use_buffer_ops, run_bench):
@param_bool(
"use_water_pipeline",
"water",
values=[False, pytest.param(True, marks=require_water_and_ee)],
)
def test_dynamic_copy(shape, use_buffer_ops, run_bench, use_water_pipeline):
M = tkl.sym.M
N = tkl.sym.N
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
Expand Down Expand Up @@ -238,6 +250,7 @@ def test(
canonicalize=True,
run_bench=run_bench,
use_buffer_ops=use_buffer_ops,
use_water_pipeline=use_water_pipeline,
)
options = set_default_run_config(options)
test = wave_compile(options, test)
Expand Down
5 changes: 5 additions & 0 deletions water/tools/water-opt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ set(LIBS

MLIRWaterTestTransforms
MLIRWaterTestDialect

MLIRRegisterAllDialects
MLIRRegisterAllExtensions
MLIRRegisterAllPasses
Comment on lines +16 to +18
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please don't. Just pick what you need. Otherwise this significantly increases build times and distribution sizes, and is near impossible to undo.

MLIRToLLVMIRTranslationRegistration
)

add_llvm_executable(water-opt
Expand Down
19 changes: 8 additions & 11 deletions water/tools/water-opt/water-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllExtensions.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Target/LLVMIR/Dialect/All.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
#include "mlir/Transforms/Passes.h"

Expand All @@ -40,20 +42,15 @@ int main(int argc, char **argv) {
mlir::water::registerPasses();
mlir::water::test::registerAllPasses();
wave::registerPasses();
mlir::arith::registerArithIntRangeOptsPass();
mlir::registerCanonicalizerPass();
mlir::registerCSEPass();
mlir::registerLoopInvariantCodeMotionPass();
mlir::registerLowerAffinePass();

mlir::registerAllPasses();
mlir::DialectRegistry registry;
registry.insert<mlir::affine::AffineDialect, mlir::amdgpu::AMDGPUDialect,
mlir::arith::ArithDialect, mlir::cf::ControlFlowDialect,
mlir::func::FuncDialect, mlir::gpu::GPUDialect,
mlir::LLVM::LLVMDialect, mlir::ROCDL::ROCDLDialect,
mlir::memref::MemRefDialect, mlir::scf::SCFDialect,
mlir::vector::VectorDialect, wave::WaveDialect>();
mlir::registerAllDialects(registry);
mlir::registerAllExtensions(registry);
Comment on lines +48 to +49
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please don't.


mlir::registerAllGPUToLLVMIRTranslations(registry);

registry.insert<wave::WaveDialect>();
mlir::water::test::registerWaterTestDialect(registry);

return mlir::asMainReturnCode(
Expand Down
Loading