A one-file, fast, fully JIT-compilable GridWorld environment with a PPO implementation, built on JAX and Flax's experimental NNX library.
NNX-Gridworld provides a pure JAX implementation of a simple grid world environment with PPO training. The entire pipeline is JIT-compiled for maximum performance, making it ideal for rapid experimentation and research.
Trained agent navigating to the goal:
- One file
- Pure JAX Implementation: Fully vectorized and JIT-compilable environment
- Flax NNX: Leverages the NNX API for cleaner neural network definitions
- Two Learning Modes:
  - State-based learning ((x, y) coordinates)
  - Vision-based learning (raw pixel observations)
- High Performance: Entire training loop runs on GPU/TPU with XLA compilation
- Simple & Extensible: Clean codebase designed for easy modification
```bash
# Clone the repository
git clone git@github.com:Aneeshers/NNX-Gridworld.git
cd NNX-Gridworld

# Install dependencies
pip install jax jaxlib flax optax
```

Train an agent using (x, y) coordinate observations:

```bash
python control/learning/ppo/ppo_simple_grid.py
```

Train an agent using raw pixel observations:

```bash
python control/learning/ppo/vision_ppo_simple_grid.py
```

The GridWorld environment is a simple navigation task where:
- Agent: Starts at a random position
- Goal: Placed at a random position
- Objective: Navigate to the goal in minimum steps
- Actions: Up, Down, Left, Right
- Reward: +1 for reaching goal, -0.01 per step
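Because the dynamics above are so simple, they can be expressed as a pure function of arrays, which is exactly what makes the environment JIT-compilable. The following is an illustrative sketch, not the repo's actual API; the function name, reward handling on the goal step, and grid size are assumptions:

```python
import jax
import jax.numpy as jnp

# Row/column deltas for actions 0-3: up, down, left, right.
MOVES = jnp.array([[-1, 0], [1, 0], [0, -1], [0, 1]])
GRID_SIZE = 8

@jax.jit
def step(agent_pos, goal_pos, action):
    # Apply the move and clip so the agent stays inside the grid.
    new_pos = jnp.clip(agent_pos + MOVES[action], 0, GRID_SIZE - 1)
    done = jnp.all(new_pos == goal_pos)
    # +1 on reaching the goal, -0.01 step penalty otherwise (assumed split).
    reward = jnp.where(done, 1.0, -0.01)
    return new_pos, reward, done
```

Keeping the state as plain arrays (no Python branching on values) is what lets XLA trace and compile the whole transition.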
The PPO agent learns to solve the gridworld task efficiently.

State-based model:
- Input: agent and goal positions ((x, y) coordinates)
- Network: 2-layer MLP (64 hidden units each)
- Output: action probabilities and value estimate

Vision-based model:
- Input: 8x8 RGB image of the grid
- Network: CNN encoder + MLP heads
- Output: action probabilities and value estimate
The entire training loop is JIT-compiled for maximum performance:

```python
@jax.jit
def train_step(state, batch):
    # Fully compiled PPO update
    ...
```

Parallel environment execution across batches:

```python
@jax.jit
def rollout(env_state, policy_state, key, num_steps):
    # Vectorized rollout collection
    ...
```

Clean, intuitive model definition with Flax NNX:
```python
class ActorCritic(nnx.Module):
    def __init__(self, obs_dim: int, act_dim: int, rngs: nnx.Rngs):
        self.shared = nnx.Sequential(
            nnx.Linear(obs_dim, 64, rngs=rngs),
            nnx.relu,
            nnx.Linear(64, 64, rngs=rngs),
            nnx.relu,
        )
        self.actor = nnx.Linear(64, act_dim, rngs=rngs)
        self.critic = nnx.Linear(64, 1, rngs=rngs)
```

Training for 5,000 epochs with:
- 32 parallel environments
- 32 steps per rollout
- 5 PPO epochs per update
- Runs in ~15 seconds on a modern GPU
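The parallelism behind these numbers follows a standard JAX pattern: `jax.vmap` across environments and `jax.lax.scan` across time steps, all inside one `jax.jit`. Here is a minimal toy sketch of that pattern (a 1-D stand-in environment; all names and shapes are illustrative, not the repo's code):

```python
import jax
import jax.numpy as jnp

NUM_ENVS, NUM_STEPS = 32, 32

def env_step(pos, action):
    # Toy per-environment transition: move along a line, clipped to [0, 7].
    new_pos = jnp.clip(pos + action, 0, 7)
    reward = jnp.where(new_pos == 7, 1.0, -0.01)
    return new_pos, reward

@jax.jit
def rollout(init_pos, actions):
    # actions: (NUM_STEPS, NUM_ENVS); scan over time, vmap over environments.
    def scan_step(pos, acts):
        new_pos, reward = jax.vmap(env_step)(pos, acts)
        return new_pos, reward
    final_pos, rewards = jax.lax.scan(scan_step, init_pos, actions)
    return final_pos, rewards  # rewards: (NUM_STEPS, NUM_ENVS)
```

Because the whole loop is traced once and compiled by XLA, collecting 32 x 32 transitions costs a single device launch rather than a Python loop.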
```
NNX-Gridworld/
├── control/
│   └── learning/
│       └── ppo/
│           ├── ppo_simple_grid.py         # State-based PPO
│           ├── vision_ppo_simple_grid.py  # Vision-based PPO
│           ├── ppo_training_results.png   # Training curves
│           └── state_rollout_b3.gif       # Expert demo
└── README.md
```

