Draft: Devdocs #1684

94 changes: 94 additions & 0 deletions DEVDOCS.md
@@ -0,0 +1,94 @@
# Enzyme-JAX Developer Notes

## Building the Project

### Quick Build
```bash
bazel build --repo_env=CC=clang-18 --color=yes -c dbg :enzymexlamlir-opt
```

### Build Artifacts
- **Main tool**: `enzymexlamlir-opt` (bazel target: `:enzymexlamlir-opt`)
- This is the MLIR optimization tool driver for Enzyme-XLA
- Analogous to `mlir-opt`, drives compiler passes and transformations
- Located in: `src/enzyme_ad/jax/enzymexlamlir-opt.cpp`

- **Python wheel**: `bazel build :wheel`

### Generate LSP Support
```bash
bazel run :refresh_compile_commands
```

## Project Structure

### Core Components

#### 1. **Dialects** (`src/enzyme_ad/jax/Dialect/`)
MLIR dialects define custom operations and types for Enzyme-JAX.

- **EnzymeXLAOps.td** - Dialect operation definitions
- GPU operations: `kernel_call`, `memcpy`, `gpu_wrapper`, `gpu_block`, `gpu_thread`
- JIT/XLA operations: `jit_call`, `xla_wrapper`
- Linear algebra (BLAS/LAPACK): `symm`, `syrk`, `trmm`, `lu`, `getrf`, `gesvd`, etc.
- Special functions: Bessel functions, GELU, ReLU
- Utility operations: `memref2pointer`, `pointer2memref`, `subindex`

- **EnzymeXLAAttrs.td** - Custom attribute definitions (LAPACK enums, etc.)
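
For orientation, an ODS operation declaration in these `.td` files has roughly the following shape. This is an illustrative sketch only: the op name, base class, and fields below are invented for the example, not copied from `EnzymeXLAOps.td`.

```tablegen
// Hypothetical sketch of an ODS operation definition; the name
// "my_example" and the EnzymeXLA_Op base class are illustrative.
def MyExampleOp : EnzymeXLA_Op<"my_example"> {
  let summary = "Illustrative example operation";
  let arguments = (ins AnyTensor:$input);
  let results = (outs AnyTensor:$output);
}
```

Tablegen expands such definitions into the C++ operation classes, accessors, and verifiers at build time.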

#### 2. **Passes** (`src/enzyme_ad/jax/Passes/`)
MLIR passes implement transformations and optimizations.

- Tablegen definitions in `src/enzyme_ad/jax/Passes/Passes.td`
> **Review comment (Member):** can we add a bit more detail here? For example, `enzymehloopt.cpp` contains (nearly) all the StableHLO/linear-algebra-level transformations. We can also make a subsection for the (Polygeist) raising passes?

- **EnzymeHLOOpt.cpp** - Core optimization patterns for StableHLO and EnzymeXLA operations; this file contains (nearly) all of the StableHLO tensor-level optimizations.

#### 3. **Transform Operations** (`src/enzyme_ad/jax/TransformOps/`)
To allow more granular control over which patterns are applied, each pattern is also registered as a transform operation.
For example:
```
def AndPadPad : EnzymeHLOPatternOp<
"and_pad_pad"> {
let patterns = ["AndPadPad"];
}
```
This exposes the `AndPadPad` pattern (defined in `EnzymeHLOOpt.cpp`) to `enzymexlamlir-opt`, so it can be invoked as:
```
enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=and_pad_pad" --transform-interpreter --enzyme-hlo-remove-transform -allow-unregistered-dialect input.mlir
```

> **Review comment (Member):** maybe also a fourth section for derivative rules? [cc @avik-pal]

## Common Development Tasks

### Adding a New Optimization Pattern

1. Define the pattern class in `src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp`
2. Inherit from `mlir::OpRewritePattern<OpType>`
3. Implement `matchAndRewrite()` method
4. Register in `EnzymeHLOOptPass::runOnOperation()`
5. Register as Transform operation in `TransformOps.td`
> **Review comment (Collaborator):** We can also say to add to `primitives.py`, which I often forget.
>
> **Reply (Collaborator, Author):** Thanks! Feel free to just push, otherwise I can do it later. I didn't even know about `primitives.py` 🙈

6. Add the pass to the appropriate pass list in `src/enzyme_ad/jax/primitives.py`
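
Steps 1-3 can be sketched as follows. This is a minimal illustration, assuming a hypothetical pattern that folds a double `stablehlo.negate`; the pattern name is invented for the example, and the snippet presumes the usual MLIR/StableHLO headers from the build are available.

```cpp
#include "mlir/IR/PatternMatch.h"
#include "stablehlo/dialect/StablehloOps.h"

// Hypothetical pattern: fold negate(negate(x)) -> x.
struct NegateNegateSimplify
    : public mlir::OpRewritePattern<mlir::stablehlo::NegOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::stablehlo::NegOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Match only if the operand is itself produced by a negate.
    auto inner = op.getOperand().getDefiningOp<mlir::stablehlo::NegOp>();
    if (!inner)
      return mlir::failure();
    // Replace the outer negate with the inner negate's operand.
    rewriter.replaceOp(op, inner.getOperand());
    return mlir::success();
  }
};
```

The pattern would then be added to the pattern set in `EnzymeHLOOptPass::runOnOperation()` and mirrored as a transform operation in `TransformOps.td`, as in steps 4-5.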

### Adding a New Dialect Operation

1. Define operation in `src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td`
2. Specify arguments, results, and traits
3. Implement operation class if needed in `src/enzyme_ad/jax/Dialect/Ops.cpp`
4. TODO: write about derivative rules?

## Testing

Run tests with:
> **Review comment (Member):** maybe also add a link to what a lit test is (e.g. syntax is command at top, file below, with checks, link to full spec), where tests go, etc.?

```bash
bazel test //test/...
```
This runs all of the tests under `test/`.

Most of the Enzyme-JAX tests use [lit](https://llvm.org/docs/CommandGuide/lit.html) for testing.
These tests are stored in `test/lit_tests`.
A lit test contains one or more RUN directives at the top of a file.
e.g. in `test/lit_tests/if.mlir`:
```mlir
// RUN: enzymexlamlir-opt %s --enzyme-hlo-opt | FileCheck %s
```
This instructs `lit` to run `enzymexlamlir-opt` with the `--enzyme-hlo-opt` pass on `test/lit_tests/if.mlir`.
The output is piped to `FileCheck`, which compares it against the expected result provided in comments in the same file that start with `// CHECK`.
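
Putting these pieces together, a minimal lit test has the shape below. This is an illustrative sketch, not an existing test file; the function body is invented for the example, and the exact CHECK lines in a real test depend on what the pass actually produces.

```mlir
// RUN: enzymexlamlir-opt %s --enzyme-hlo-opt | FileCheck %s

// CHECK-LABEL: func.func @negate_twice
func.func @negate_twice(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  %0 = stablehlo.negate %arg0 : tensor<4xf32>
  %1 = stablehlo.negate %0 : tensor<4xf32>
  return %1 : tensor<4xf32>
}
```

Each `// CHECK` line anchors a pattern that must appear, in order, in the pass output; `CHECK-LABEL` delimits per-function check regions so failures stay local to one function.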
10 changes: 6 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
# Enzyme-JAX

Custom bindings for Enzyme automatic differentiation tool and interfacing with
JAX. Currently this is set up to allow you to automatically import, and
automatically differentiate (both jvp and vjp) external C++ code into JAX. As
Enzyme is language-agnostic, this can be extended for arbitrary programming
Enzyme-JAX is a C++ project whose original aim was to integrate the Enzyme automatic differentiation tool [1] with JAX, enabling automatic differentiation of external C++ code within JAX. It has since expanded to incorporate Polygeist's [2] high-performance raising, parallelization, and cross-compilation workflow, as well as numerous tensor, linear algebra, and communication optimizations. The project uses LLVM's MLIR framework for intermediate representation and transformation of code. As Enzyme is language-agnostic, this can be extended to arbitrary programming
languages (Julia, Swift, Fortran, Rust, and even Python)!

You can use
@@ -77,3 +74,8 @@ Enzyme-Jax exposes a bunch of different tensor rewrites as MLIR passes in `src/e
```bash
bazel run :refresh_compile_commands
```

# References
[1] Moses, William, and Valentin Churavy. "Instead of rewriting foreign code for machine learning, automatically synthesize fast gradients." Advances in neural information processing systems 33 (2020): 12472-12485.

[2] Moses, William S., et al. "Polygeist: Raising C to polyhedral MLIR." 2021 30th International Conference on Parallel Architectures and Compilation Techniques (PACT). IEEE, 2021.