Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions .github/pull_request_template.md
Original file line number Diff line number Diff line change
@@ -1,20 +1,15 @@
## Pull Request Overview

### Summary
Describe the purpose and key changes of this PR.

### What Was Changed
### Summary // What Was Changed
-

### Quality Control
- [ ] Added/updated tests
- [ ] All tests pass locally
- [ ] Linter and type checks pass
- Linter, tests, and type checks all pass? (y/n):


### Related Issues
- Fixes: #
- Related: #


### Screenshots / Notes
<!-- Add screenshots or extra context if needed -->
5 changes: 3 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ jobs:
- name: Install dependencies
run: uv sync

- name: Run linting
run: uv run ruff check .
# Linting step disabled for now: its reporting is ambiguous.
# - name: Run linting
# run: uv run ruff check .

- name: Run type checking
run: uv run mypy src/
Expand Down
9 changes: 9 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
# CU MIND
# Files
scratch.ipynb
test.json
test.json.bak

# Directory
checkpoints/
logs/
wandb/


# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
9 changes: 9 additions & 0 deletions Release_Notes.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@ CuMind Release Notes
-------------------------------------------------------------------
Document all technical changes introduced in this release as concise bullet points below.


v0.1.95 (2025-08-03)
----------------------
Tarek Ibrahim (68++)
- Added Weights & Biases (wandb) integration for experiment tracking
- Trainer improvements: debug mode, configurable number of batches, multi-backend support
- Added TqdmSink for progress bar logging
- Minor code cleanups and test updates

v0.1.8 (2025-07-18)
----------------------
Tarek Ibrahim (68)
Expand Down
41 changes: 26 additions & 15 deletions configuration.json
Original file line number Diff line number Diff line change
@@ -1,42 +1,47 @@
{
"CuMind": {
"networks": {
"hidden_dim": 128
"hidden_state_dim": 128
},
"representation": {
"type": "cumind.core.resnet.ResNet",
"num_blocks": 2,
"num_hidden_layers": 2,
"conv_channels": 32,
"seed": 42
},
"dynamics": {
"type": "cumind.core.mlp.MLPWithEmbedding",
"num_blocks": 2,
"hidden_dim": 128,
"num_hidden_layers": 2,
"seed": 42
},
"prediction": {
"type": "cumind.core.mlp.MLPDual",
"hidden_dim": 128,
"num_hidden_layers": 2,
"seed": 42
},
"memory": {
"type": "cumind.data.memory.MemoryBuffer",
"type": "cumind.data.memory.PrioritizedMemoryBuffer",
"capacity": 2000,
"min_size": 100,
"min_size": 200,
"min_pct": 0.1,
"per_alpha": 0.6,
"per_epsilon": 1e-06,
"per_beta": 0.4
"alpha": 0.6,
"epsilon": 1e-06,
"beta": 0.4
},
"training": {
"optimizer": "optax.adamw",
"batch_size": 64,
"learning_rate": 0.01,
"num_batches": 1,
"learning_rate": 0.001,
"weight_decay": 0.0001,
"target_update_frequency": 250,
"target_update_frequency": 100,
"checkpoint_interval": 50,
"num_episodes": 1220,
"num_episodes": 2000,
"train_frequency": 2,
"checkpoint_root_dir": "checkpoints"
"checkpoint_dir": "checkpoints",
"debug": false
},
"mcts": {
"num_simulations": 25,
Expand All @@ -47,7 +52,8 @@
"env": {
"name": "CartPole-v1",
"action_space_size": 2,
"observation_shape": [4]
"observation_shape": [4],
"max_episode_steps": 500
},
"selfplay": {
"num_unroll_steps": 5,
Expand All @@ -60,13 +66,18 @@
"target": "float32"
},
"logging": {
"wandb": false,
"title": "spamEggs",
"tags": ["foo", "bar"],
"dir": "logs",
"level": "INFO",
"console": true,
"timestamps": false,
"timestamps": true,
"tqdm": false
},
"device": "cpu",
"seed": 42
"seed": 42,
"validate": true,
"multi_device": false
}
}
9 changes: 2 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# =========================
[project]
name = "cumind"
version = "0.1.8"
version = "0.1.95"
description = "CuMind is a JAX-based RL framework inspired by Google DeepMind. It achieves superhuman performance in complex domains without pretraining or prior knowledge of their rules."
readme = "README.md"
requires-python = ">=3.12"
Expand All @@ -17,7 +17,6 @@ dependencies = [
"optax>=0.2.5",
"chex>=0.1.89",
"wandb>=0.20.1",
"tensorboard>=2.19.0",
"ipython>=9.3.0",
"ipykernel>=6.29.5",
]
Expand All @@ -29,11 +28,7 @@ dependencies = [
# - chex: Testing and assertion utilities for JAX

# Note: JAX CPU version is used by default
# To install other JAX versions, use:
# pip install -U jax[cuda12] # for NVIDIA GPUs with CUDA 12
# pip install -U jax[tpu] # for TPU
# pip install -U jax[rocm] # for AMD GPUs
# pip install -U jax[metal] # for Apple Silicon
# See the device validation in config.py (around line 385).


# =========================
Expand Down
12 changes: 9 additions & 3 deletions src/cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,16 @@

def main() -> None:
"""Main function for running the CartPole example."""
cfg.load("configuration.json")
ckpt = train()
timestamp, checkpoint_dir = cfg.load("test.json")
print(f"Run UUID: {timestamp}")

train()
log.info(f"Training completed in {log.elapsed()}.")
inference(ckpt)

latest_ckpt = f"{checkpoint_dir}/episode_{cfg.training.num_episodes:05d}.pkl"

log.open()
inference(latest_ckpt, 500)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion src/cumind/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""CuMind: A modular reinforcement learning framework."""

__version__ = "0.1.8"
__version__ = "0.1.95"

# Most commonly used components
from .agent import Agent, inference, train
Expand Down
31 changes: 18 additions & 13 deletions src/cumind/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from cumind.core.network import CuMindNetwork
from cumind.utils.config import cfg
from cumind.utils.logger import log
from cumind.utils.prng import key


class Agent:
Expand All @@ -34,9 +33,6 @@ def __init__(self, existing_state: Optional[Dict[str, Any]] = None):
with jax.default_device(self.device):
self.network = CuMindNetwork(representation_network=cfg.representation(), dynamics_network=cfg.dynamics(), prediction_network=cfg.prediction())

log.info("Creating target network.")
self.target_network = nnx.clone(self.network)

log.info(f"Setting up AdamW optimizer with learning rate {cfg.training.learning_rate} and weight decay {cfg.training.weight_decay}")
self.optimizer = optax.adamw(learning_rate=cfg.training.learning_rate, weight_decay=cfg.training.weight_decay)

Expand All @@ -46,6 +42,9 @@ def __init__(self, existing_state: Optional[Dict[str, Any]] = None):
else:
log.info("Initializing new optimizer state.")
self.optimizer_state = self.optimizer.init(nnx.state(self.network, nnx.Param))
# Ensure target network is properly initialized
log.info("Initializing target prediction network.")
self.network.update_target_prediction_network(hard=True)

self.mcts = MCTS(self.network)
log.info("Agent initialization complete.")
Expand All @@ -60,6 +59,11 @@ def select_action(self, observation: np.ndarray, training: bool = False) -> Tupl
Returns:
A tuple containing the selected action index and the MCTS policy probabilities.
"""
if cfg.training.debug:
num_actions = cfg.env.action_space_size
action_probs = np.ones(num_actions, dtype=np.float32) / num_actions
action_idx = int(np.random.choice(num_actions))
return action_idx, action_probs
log.debug(f"Selecting action. Training mode: {training}")

obs_tensor = jax.device_put(jnp.array(observation)[None], self.device) # [None] adds batch dimension
Expand All @@ -69,22 +73,23 @@ def select_action(self, observation: np.ndarray, training: bool = False) -> Tupl

# Use MCTS to get action probabilities
action_probs = self.mcts.search(root_hidden_state=hidden_state_array, add_noise=training)

# Take best action
action_idx = int(np.argmax(action_probs))
"""
if training:
# Sample action from probabilities
action_idx = int(jax.random.choice(key.get(), len(action_probs), p=action_probs))
else:
# Take best action
action_idx = int(np.argmax(action_probs))

log.debug(f"Selected action: {action_idx}")
"""
# log.debug(f"Selected action: {action_idx}")
return int(action_idx), action_probs

def update_target_network(self) -> None:
"""Update the target network's weights with the main network's weights."""
log.debug("Updating target network.")
online_params = nnx.state(self.network, nnx.Param)
nnx.update(self.target_network, online_params)
"""Update the target prediction network's weights with the main network's weights."""
log.debug("Updating target prediction network.")
self.network.update_target_prediction_network(hard=False, tau=0.01)

def save_state(self) -> Dict[str, Any]:
"""Get the current state of the agent for checkpointing.
Expand All @@ -108,6 +113,6 @@ def load_state(self, state: Dict[str, Any]) -> None:
nnx.update(self.network, state["network_state"])
self.optimizer_state = state["optimizer_state"]

log.info("Updating target network after loading state.")
self.update_target_network()
log.info("Updating target prediction network after loading state.")
self.network.update_target_prediction_network(hard=True)
log.info("Agent state loaded successfully.")
13 changes: 6 additions & 7 deletions src/cumind/agent/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,35 +13,34 @@

def train() -> str:
"""Train the agent on a given environment."""
env = gym.make(cfg.env.name)
env = gym.make(id=cfg.env.name, max_episode_steps=cfg.env.max_episode_steps)

agent = Agent()
memory_buffer = cfg.memory()
trainer = Trainer(agent, memory_buffer)

trainer.run_training_loop(env)
trainer.train(env)

env.close() # type: ignore

return trainer.checkpoint_dir


def inference(checkpoint_file: str) -> None:
def inference(checkpoint_file: str, num_episodes: int) -> None:
"""Run inference with a trained agent from a checkpoint."""
log.info("\nStarting inference.")

if not os.path.isfile(checkpoint_file):
log.error(f"Checkpoint file not found: {checkpoint_file}")
return
raise RuntimeError(f"Checkpoint file not found: {checkpoint_file}")

log.info(f"Loading agent from: {checkpoint_file}")

inference_agent = Agent()
state = load_checkpoint(checkpoint_file)
inference_agent.load_state(state)

env = gym.make(cfg.env.name, render_mode="human")
for episode in range(500):
env = gym.make(id=cfg.env.name, max_episode_steps=cfg.env.max_episode_steps, render_mode="human")
for episode in range(num_episodes):
obs, _ = env.reset()
done = False
total_reward = 0.0
Expand Down
Loading