Slope is a small automatic differentiation (AD) engine focused on machine learning (ML), supporting forward-mode, reverse-mode, and higher-order AD.

This project is designed to be a small, hackable, and educational AD engine focused on ML, yet able to go end-to-end from training to deployment, rather than just running simple toy examples.

Tensor semantics are similar to PyTorch, the functional API is similar to JAX, and the tensor operator code is heavily derived from tinygrad.
# Example

```python
import slope

def f(x):
    y = x * 2.0
    return y.sum()

x = slope.tensor([1., 2., 3.])
gf_x = slope.grad(f)(x)
print(f"{gf_x=}")
```

```
gf_x=<Tensor: val=
[2. 2. 2.]
shape=(3,), dtype=float32, device='cpu:0'>
```
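To illustrate what `grad` is doing conceptually, here is a minimal, self-contained reverse-mode AD sketch in plain Python. This is not slope's actual implementation, just the core idea: record each operation on a tape, then replay the tape backwards to accumulate gradients.

```python
# Minimal reverse-mode AD sketch: a tape of backward closures records each
# op; walking the tape in reverse accumulates d(output)/d(input).
class Var:
    def __init__(self, val, tape=None):
        self.val, self.grad = val, 0.0
        self.tape = tape if tape is not None else []

    def __mul__(self, c):  # scalar-constant multiply, enough for f(x) = x * 2
        out = Var(self.val * c, self.tape)
        # d(out)/d(self) = c, so propagate c * out.grad back to self
        self.tape.append(lambda: setattr(self, "grad", self.grad + c * out.grad))
        return out

def grad_sum_f(f, xs):
    """Gradient of sum(f(x_i)) w.r.t. each x_i (mimics f(x).sum() then grad)."""
    tape = []
    vars_ = [Var(x, tape) for x in xs]
    outs = [f(v) for v in vars_]
    for o in outs:              # d(sum)/d(out_i) = 1
        o.grad = 1.0
    for back in reversed(tape): # replay the tape backwards
        back()
    return [v.grad for v in vars_]

print(grad_sum_f(lambda x: x * 2.0, [1.0, 2.0, 3.0]))  # [2.0, 2.0, 2.0]
```

The result matches the slope example above: the gradient of `(x * 2.0).sum()` is 2 for every element of `x`.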
# Install

```sh
pip install slope-ad
```

or latest from main branch:

```sh
git clone https://github.com/radenmuaz/slope-ad
cd slope-ad
pip install -e .
```

or you can just copy `src/slope` into your project.
# Features

- Small (?): <3000 lines of core code in `slope/core.py`, after formatting with `black src --line-length 140`
- Functional API for forward-mode, reverse-mode, and higher-order AD, like in JAX:
  - `grad`, `vjp`, `jvp`, `jit`, `vmap`
  - `register_node`, `tree_flatten`, `tree_unflatten`
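The `tree_flatten` / `tree_unflatten` pair follows JAX's pytree idea: nested containers of tensors are flattened into a list of leaves plus a structure description, so transformations can work on flat argument lists and then rebuild the original nesting. A rough self-contained sketch of the concept (handling only lists and dicts; not slope's actual code):

```python
def tree_flatten(tree):
    """Flatten nested lists/dicts into (leaves, structure)."""
    if isinstance(tree, list):
        flat, structs = [], []
        for item in tree:
            leaves, s = tree_flatten(item)
            flat += leaves
            structs.append(s)
        return flat, ("list", structs)
    if isinstance(tree, dict):
        flat, structs = [], []
        for k in sorted(tree):  # sort keys for a deterministic leaf order
            leaves, s = tree_flatten(tree[k])
            flat += leaves
            structs.append((k, s))
        return flat, ("dict", structs)
    return [tree], "leaf"       # anything else is a leaf

def tree_unflatten(structure, leaves):
    """Inverse of tree_flatten: rebuild the nesting, consuming leaves in order."""
    it = iter(leaves)
    def build(s):
        if s == "leaf":
            return next(it)
        kind, children = s
        if kind == "list":
            return [build(c) for c in children]
        return {k: build(c) for k, c in children}
    return build(structure)

params = {"w": [1.0, 2.0], "b": 3.0}
leaves, struct = tree_flatten(params)
print(leaves)  # [3.0, 1.0, 2.0] -- "b" sorts before "w"
assert tree_unflatten(struct, leaves) == params
```

`register_node` plays the analogous role of teaching the flattener about new container types, as in JAX's pytree registry.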
- Just-in-time compilation, where code is compiled for these supported backends, running on CPU, CUDA, or Metal:
  - ONNX Runtime (ONNX graph)
  - OpenXLA IREE (StableHLO MLIR)
  - NumPy (Python code)
- Training and inference examples
- Operators and procedures system
  - 33 core operators defined in `slope/operators.py`:
    - Unary: `exp`, `log`, `sin`, `sqrt`, `invert`, `cast`, `stop_gradient`
    - Binary: `add`, `mul`, `sub`, `div`, `pow`, `equal`, `less`, `greater`, `maximum`
    - Reduce: `sum`, `max`
    - Shape: `reshape`, `expand`, `permute`, `slice`, `pad`, `flip`, `cat`
    - Init: `full`, `arange`, `random_normal`, `random_uniform`
    - GeneralReduce: `matmul`, `conv`, `gather_nd`, `scatter_nd`
  - Composite operators system with "procedures" in `slope/procedures.py`
    - For defining Tensor functions composed of core operators, e.g. `x.cos()`, where `def cos(x): return (math.pi/2 - x).sin()`, or `x.conv_transpose(w)`, where `def conv_transpose(x, w, ...): ...` is a very long function.
    - Procedures are exposed with `Tensor.procedure_name(*args)` syntax.
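The `cos` procedure above leans on the identity cos(x) = sin(π/2 − x), which lets a composite operator be expressed purely through the core `sin` operator. A quick check of that identity using the Python standard library (plain floats here, not slope tensors):

```python
import math

def cos_via_sin(x):
    # Same composition the cos procedure uses: cos(x) = sin(pi/2 - x)
    return math.sin(math.pi / 2 - x)

# The composite agrees with the direct cosine across sample points.
for x in [-3.0, -1.0, 0.0, 0.5, 2.0]:
    assert abs(cos_via_sin(x) - math.cos(x)) < 1e-12
```

Keeping the core operator set small and building everything else as procedures is what keeps `slope/core.py` compact: only the 33 core operators need gradients and backend translations.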
- Extensible
  - Add a new backend by defining implementation translations in `slope/backends`
  - Define new modules with the NN module `slope/nn.py`
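A backend is essentially a table translating each core operator into target code. Here is a toy, self-contained sketch of the idea as an interpreter-style "NumPy-like backend" over plain Python lists; slope's real backends instead emit compiled artifacts (ONNX graphs, StableHLO MLIR, or NumPy source), and the op names here are just illustrative:

```python
import math

# Toy backend: a translation table from core op names to implementations.
IMPL = {
    "mul": lambda a, b: [x * y for x, y in zip(a, b)],
    "exp": lambda a: [math.exp(x) for x in a],
    "sum": lambda a: [sum(a)],
}

def run(program, inputs):
    """Interpret a list of (op, arg_names, out_name) steps against IMPL."""
    env = dict(inputs)
    out = None
    for op, args, out in program:
        env[out] = IMPL[op](*(env[a] for a in args))
    return env[out]

prog = [("mul", ("x", "x"), "y"),   # y = x * x
        ("sum", ("y",), "z")]       # z = sum(y)
print(run(prog, {"x": [1.0, 2.0, 3.0]}))  # [14.0]
```

Supporting a new target then amounts to supplying a new translation table (or code emitter) for the same small set of core operators.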
# Docs

Docs are available online at radenmuaz.github.io/slope-ad. API reference: radenmuaz.github.io/slope-ad/api

- Quickstart: how Tensors work, how to write and jit-compile functions, and how to train something.
- NN Training: train an MLP on MNIST with the `slope.nn` module.
- Internals Walkthrough: understand the core of SlopeAD (hint: it is like JAX). Useful if you want to start contributing to SlopeAD.
- Extending SlopeAD: add a new backend, operators, and procedures; modify the core functions.
# Contributing

Open a PR; the items on the roadmap below need to be done.

- Docs
- Symbolic shape inference
- Dynamic-shape jit
- Make the optimizer filter out frozen params
- vmap of vjp and jvp to compute Jacobians and Hessians
- The IREE backend currently has fixed-seed random; implement Threefry and JAX-like random
- Make things fast
- LLaMA (GPT) training
- Whisper inference
- Core tests, operator tests on all Trace types