-
Notifications
You must be signed in to change notification settings - Fork 24
Devdocs #1684
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Devdocs #1684
Changes from all commits
c12346f
a7895e1
0557e05
7a88503
9f34299
6344e6a
20c31f0
1434e4b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,94 @@ | ||
| # Enzyme-JAX Developer Notes | ||
|
|
||
| ## Building the Project | ||
|
|
||
| ### Quick Build | ||
| ```bash | ||
| bazel build --repo_env=CC=clang-18 --color=yes -c dbg :enzymexlamlir-opt | ||
| ``` | ||
|
|
||
| ### Build Artifacts | ||
| - **Main tool**: `enzymexlamlir-opt` (bazel target: `:enzymexlamlir-opt`) | ||
| - This is the MLIR optimization tool driver for Enzyme-XLA | ||
| - Analogous to `mlir-opt`, drives compiler passes and transformations | ||
| - Located in: `src/enzyme_ad/jax/enzymexlamlir-opt.cpp` | ||
|
|
||
| - **Python wheel**: `bazel build :wheel` | ||
|
|
||
| ### Generate LSP Support | ||
| ```bash | ||
| bazel run :refresh_compile_commands | ||
| ``` | ||
|
|
||
| ## Project Structure | ||
|
|
||
| ### Core Components | ||
|
|
||
| #### 1. **Dialects** (`src/enzyme_ad/jax/Dialect/`) | ||
| MLIR dialects define custom operations and types for Enzyme-JAX. | ||
|
|
||
| - **EnzymeXLAOps.td** - Dialect operation definitions | ||
| - GPU operations: `kernel_call`, `memcpy`, `gpu_wrapper`, `gpu_block`, `gpu_thread` | ||
| - JIT/XLA operations: `jit_call`, `xla_wrapper` | ||
| - Linear algebra (BLAS/LAPACK): `symm`, `syrk`, `trmm`, `lu`, `getrf`, `gesvd`, etc. | ||
| - Special functions: Bessel functions, GELU, ReLU | ||
| - Utility operations: `memref2pointer`, `pointer2memref`, `subindex` | ||
|
|
||
| - **EnzymeXLAAttrs.td** - Custom attribute definitions (LAPACK enums, etc.) | ||
|
|
||
| #### 2. **Passes** (`src/enzyme_ad/jax/Passes/`) | ||
| MLIR passes implement transformations and optimizations. | ||
|
|
||
| - Tablegen definitions in `src/enzyme_ad/jax/Passes/Passes.td` | ||
| - **EnzymeHLOOpt.cpp** - Core optimization patterns for StableHLO and EnzymeXLA operations | ||
| This file contains (nearly) all the stablehlo tensor optimizations. | ||
|
|
||
| #### 3. **Transform Operations** (`src/enzyme_ad/jax/TransformOps/`) | ||
| In order to have more granular control over which pattern is applied, patterns are also registered as transform operations. | ||
| For example: | ||
| ``` | ||
| def AndPadPad : EnzymeHLOPatternOp< | ||
| "and_pad_pad"> { | ||
| let patterns = ["AndPadPad"]; | ||
| } | ||
| ``` | ||
| Exposes the `AndPadPad` pattern (defined in `EnzymeHLOOpt.cpp`) to `enzymexlamlir-opt`, so it can be used as: | ||
| ``` | ||
| enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=and_pad_pad" --transform-interpreter --enzyme-hlo-remove-transform -allow-unregistered-dialect input.mlir | ||
| ``` | ||
|
|
||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe also a fourth section for derivative rules ? [cc @avik-pal ] |
||
| ## Common Development Tasks | ||
|
|
||
| ### Adding a New Optimization Pattern | ||
|
|
||
| 1. Define the pattern class in `src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp` | ||
| 2. Inherit from `mlir::OpRewritePattern<OpType>` | ||
| 3. Implement `matchAndRewrite()` method | ||
| 4. Register in `EnzymeHLOOptPass::runOnOperation()` | ||
| 5. Register as Transform operation in `TransformOps.td` | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can also say to add to primitives.py which I often forget
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks! Feel free to just push, otherwise I can do it later. I didn't even know about primitives.py 🙈 |
||
| 6. Add the pass to the appropriate pass list in `src/enzyme_ad/jax/primitives.py` | ||
|
|
||
| ### Adding a New Dialect Operation | ||
|
|
||
| 1. Define operation in `src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td` | ||
| 2. Specify arguments, results, and traits | ||
| 3. Implement operation class if needed in `src/enzyme_ad/jax/Dialect/Ops.cpp` | ||
| 4. TODO: write about derivative rules? | ||
|
|
||
| ## Testing | ||
|
|
||
| Run tests with: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe also add a link to what a lit test is (e.g. syntax is command at top, file below, with checks, link to full spec), where tests go/etc? |
||
| ```bash | ||
| bazel test //test/... | ||
| ``` | ||
romanlee marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| This runs the tests in | ||
|
|
||
| Most of the Enzyme-JaX tests use [lit](https://llvm.org/docs/CommandGuide/lit.html) for testing. | ||
| These tests are stored in `test/lit_tests`. | ||
| A lit test contains one or more run directives at the top a file. | ||
| e.g. in `test/lit_tests/if.mlir`: | ||
| ```mlir | ||
| // RUN: enzymexlamlir-opt %s --enzyme-hlo-opt | FileCheck %s | ||
| ``` | ||
| This instructs `lit` to run the `enzyme-hlo-opt` pass on `test/lit_tests/if.mlir`. | ||
| The output is fed to `FileCheck` which compares it against the expected result that is provided in comments in the file that start with `// CHECK`. | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we add a bit more detail here?
for example enzymehloopt.cpp contains (nearly) all the stablehlo/linear algebra level transformations
we can also make a subsection for the (polygeist) raising passes?