From e12d2b9c3ca9143a84cd64329f4a0441c05bff7b Mon Sep 17 00:00:00 2001
From: "Tristan F."
Date: Sun, 3 Nov 2024 20:46:21 -0800
Subject: [PATCH] chore: drop reinforcement

---
 crates/game-solver/src/disjoint_game.rs       |  21 +-
 crates/game-solver/src/reinforcement/agent.rs |  20 --
 crates/game-solver/src/reinforcement/mod.rs   | 278 ------------------
 crates/game-solver/src/reinforcement/state.rs |  27 --
 .../src/reinforcement/strategy/explore.rs     |  26 --
 .../src/reinforcement/strategy/mod.rs         |   2 -
 .../src/reinforcement/strategy/terminate.rs   |   7 -
 7 files changed, 17 insertions(+), 364 deletions(-)
 delete mode 100644 crates/game-solver/src/reinforcement/agent.rs
 delete mode 100644 crates/game-solver/src/reinforcement/mod.rs
 delete mode 100644 crates/game-solver/src/reinforcement/state.rs
 delete mode 100644 crates/game-solver/src/reinforcement/strategy/explore.rs
 delete mode 100644 crates/game-solver/src/reinforcement/strategy/mod.rs
 delete mode 100644 crates/game-solver/src/reinforcement/strategy/terminate.rs

diff --git a/crates/game-solver/src/disjoint_game.rs b/crates/game-solver/src/disjoint_game.rs
index fa88126..d4e32b1 100644
--- a/crates/game-solver/src/disjoint_game.rs
+++ b/crates/game-solver/src/disjoint_game.rs
@@ -8,8 +8,10 @@ use crate::{game::{Game, Normal, NormalImpartial}, player::ImpartialPlayer};
 /// two impartial normal combinatorial games.
 ///
 /// Since `Game` isn't object safe, we use `dyn Any` internally with downcast safety.
+///
+/// We restrict games to being normal impartial to force implementation of the marker trait.
 #[derive(Clone)]
-pub struct DisjointImpartialNormalGame<L: Game, R: Game> {
+pub struct DisjointImpartialNormalGame<L: Game + NormalImpartial, R: Game + NormalImpartial> {
     left: L,
     right: R
 }
@@ -31,9 +33,20 @@ pub enum DisjointMoveError<L: Game, R: Game> {
 type LeftMoveMap<L, R> = Box<dyn Fn(<L as Game>::Move) -> DisjointMove<L, R>>;
 type RightMoveMap<L, R> = Box<dyn Fn(<R as Game>::Move) -> DisjointMove<L, R>>;
 
-impl<L: Game, R: Game> Normal for DisjointImpartialNormalGame<L, R> {}
-impl<L: Game, R: Game> NormalImpartial for DisjointImpartialNormalGame<L, R> {}
-impl<L: Game + Debug + 'static, R: Game + Debug + 'static> Game for DisjointImpartialNormalGame<L, R> {
+impl<
+    L: Game + Debug + NormalImpartial + 'static,
+    R: Game + Debug + NormalImpartial + 'static
+> Normal for DisjointImpartialNormalGame<L, R> {}
+
+impl<
+    L: Game + Debug + NormalImpartial + 'static,
+    R: Game + Debug + NormalImpartial + 'static
+> NormalImpartial for DisjointImpartialNormalGame<L, R> {}
+
+impl<
+    L: Game + Debug + NormalImpartial + 'static,
+    R: Game + Debug + NormalImpartial + 'static
+> Game for DisjointImpartialNormalGame<L, R> {
     type Move = DisjointMove<L, R>;
     type Iter<'a> = Interleave<
         Map<<L as Game>::Iter<'a>, LeftMoveMap<L, R>>,
diff --git a/crates/game-solver/src/reinforcement/agent.rs b/crates/game-solver/src/reinforcement/agent.rs
deleted file mode 100644
index 313cca9..0000000
--- a/crates/game-solver/src/reinforcement/agent.rs
+++ /dev/null
@@ -1,20 +0,0 @@
-use super::state::State;
-
-/// An `Agent` is something which hold a certain state, and is able to take actions from that
-/// state. After taking an action, the agent arrives at another state.
-pub trait Agent<S: State> {
-    /// Returns the current state of this `Agent`.
-    fn current_state(&self) -> &S;
-    /// Takes the given action, possibly mutating the current `State`.
-    fn take_action(&mut self, action: &S::A);
-    /// Takes a random action from the set of possible actions from this `State`. The default
-    /// implementation uses [`State::random_action()`](trait.State.html#method.random_action) to
-    /// determine the action to be taken.
-    fn pick_random_action(&mut self) -> S::A {
-        let action = self.current_state().random_action();
-
-        self.take_action(&action);
-
-        action
-    }
-}
diff --git a/crates/game-solver/src/reinforcement/mod.rs b/crates/game-solver/src/reinforcement/mod.rs
deleted file mode 100644
index f7963ab..0000000
--- a/crates/game-solver/src/reinforcement/mod.rs
+++ /dev/null
@@ -1,278 +0,0 @@
-use dfdx::prelude::*;
-
-#[cfg(feature = "save")]
-use dfdx::safetensors::SafeTensorError;
-
-use self::agent::Agent;
-use self::state::State;
-
-use self::strategy::explore::ExplorationStrategy;
-use self::strategy::terminate::TerminationStrategy;
-
-pub mod agent;
-pub mod state;
-pub mod strategy;
-
-const BATCH: usize = 64;
-
-#[derive(Default, Clone, Debug, Sequential)]
-#[built(QNetwork)]
-struct QNetWorkConfig<const STATE_SIZE: usize, const ACTION_SIZE: usize, const INNER_SIZE: usize> {
-    linear1: LinearConstConfig<STATE_SIZE, INNER_SIZE>,
-    act1: ReLU,
-    linear2: LinearConstConfig<INNER_SIZE, INNER_SIZE>,
-    act2: ReLU,
-    linear3: LinearConstConfig<INNER_SIZE, ACTION_SIZE>,
-}
-
-/// An `DQNAgentTrainer` can be trained for using a certain [Agent](mdp/trait.Agent.html). After
-/// training, the `DQNAgentTrainer` contains learned knowledge about the process, and can be queried
-/// for this. For example, you can ask the `DQNAgentTrainer` the expected values of all possible
-/// actions in a given state.
-///
-/// The code is partially taken from .
-/// and .
-pub struct DQNAgentTrainer<
-    S,
-    const STATE_SIZE: usize,
-    const ACTION_SIZE: usize,
-    const INNER_SIZE: usize,
-> where
-    S: State + Into<[f32; STATE_SIZE]>,
-    S::A: Into<[f32; ACTION_SIZE]>,
-    S::A: From<[f32; ACTION_SIZE]>,
-{
-    /// The [discount factor](https://en.wikipedia.org/wiki/Q-learning#Discount_factor) for future rewards.
-    gamma: f32,
-    /// The Q-network that is being trained.
-    q_network: QNetwork<STATE_SIZE, ACTION_SIZE, INNER_SIZE, f32, AutoDevice>,
-    /// The target Q-network that is used to compute the target Q-values.
-    target_q_net: QNetwork<STATE_SIZE, ACTION_SIZE, INNER_SIZE, f32, AutoDevice>,
-    /// The optimizer that is used to train the Q-network.
-    sgd: Sgd<QNetwork<STATE_SIZE, ACTION_SIZE, INNER_SIZE, f32, AutoDevice>, f32, AutoDevice>,
-    dev: AutoDevice,
-    /// Preserves the type of the state.
-    phantom: std::marker::PhantomData<S>,
-}
-
-impl<S, const STATE_SIZE: usize, const ACTION_SIZE: usize, const INNER_SIZE: usize>
-    DQNAgentTrainer<S, STATE_SIZE, ACTION_SIZE, INNER_SIZE>
-where
-    S: State + Into<[f32; STATE_SIZE]>,
-    S::A: Into<[f32; ACTION_SIZE]>,
-    S::A: From<[f32; ACTION_SIZE]>,
-{
-    /// Creates a new `DQNAgentTrainer` with the given parameters.
-    ///
-    /// # Arguments
-    ///
-    /// * `gamma` - The [discount factor](https://en.wikipedia.org/wiki/Q-learning#Discount_factor) for future rewards.
-    /// * `learning_rate` - The learning rate for the stochastic gradient descent optimizer.
-    ///
-    /// # Returns
-    ///
-    /// A new `DQNAgentTrainer` with the given parameters.
-    ///
-    pub fn new(
-        gamma: f32,
-        learning_rate: f64,
-    ) -> DQNAgentTrainer<S, STATE_SIZE, ACTION_SIZE, INNER_SIZE> {
-        let dev = AutoDevice::default();
-
-        // initialize model
-        let architecture: QNetWorkConfig<STATE_SIZE, ACTION_SIZE, INNER_SIZE> = Default::default();
-        let q_net = dev.build_module::<f32>(architecture);
-        let target_q_net = q_net.clone();
-
-        // initialize optimizer
-        let sgd = Sgd::new(
-            &q_net,
-            SgdConfig {
-                lr: learning_rate,
-                momentum: Some(Momentum::Nesterov(0.9)),
-                weight_decay: None,
-            },
-        );
-
-        DQNAgentTrainer {
-            gamma,
-            q_network: q_net,
-            target_q_net,
-            sgd,
-            dev,
-            phantom: std::marker::PhantomData,
-        }
-    }
-
-    /// Fetches the learned value for the given `Action` in the given `State`, or `None` if no
-    /// value was learned.
-    pub fn expected_value(&self, state: &S) -> [f32; ACTION_SIZE] {
-        let state_: [f32; STATE_SIZE] = (state.clone()).into();
-        let states: Tensor<Rank1<STATE_SIZE>, f32, _> =
-            self.dev.tensor(state_).normalize::<Axis<0>>(0.001);
-        let actions = self.target_q_net.forward(states).nans_to(0f32);
-        actions.array()
-    }
-
-    /// Returns a clone of the entire learned state to be saved or used elsewhere.
-    pub fn export_learned_values(
-        &self,
-    ) -> QNetwork<STATE_SIZE, ACTION_SIZE, INNER_SIZE, f32, AutoDevice> {
-        self.learned_values().clone()
-    }
-
-    // Returns a reference to the learned state.
-    pub fn learned_values(
-        &self,
-    ) -> &QNetwork<STATE_SIZE, ACTION_SIZE, INNER_SIZE, f32, AutoDevice> {
-        &self.q_network
-    }
-
-    /// Imports a model, completely replacing any learned progress
-    pub fn import_model(
-        &mut self,
-        model: QNetwork<STATE_SIZE, ACTION_SIZE, INNER_SIZE, f32, AutoDevice>,
-    ) {
-        self.q_network.clone_from(&model);
-        self.target_q_net.clone_from(&self.q_network);
-    }
-
-    /// Returns the best action for the given `State`, or `None` if no values were learned.
-    pub fn best_action(&self, state: &S) -> Option<S::A> {
-        let target = self.expected_value(state);
-
-        Some(target.into())
-    }
-
-    #[allow(clippy::boxed_local)]
-    pub fn train_dqn(
-        &mut self,
-        states: [[f32; STATE_SIZE]; BATCH],
-        actions: [[f32; ACTION_SIZE]; BATCH],
-        next_states: [[f32; STATE_SIZE]; BATCH],
-        rewards: [f32; BATCH],
-        dones: [bool; BATCH],
-    ) {
-        self.target_q_net.clone_from(&self.q_network);
-        let mut grads = self.q_network.alloc_grads();
-
-        let dones: Tensor<Rank1<BATCH>, f32, _> =
-            self.dev.tensor(dones.map(|d| if d { 1f32 } else { 0f32 }));
-        let rewards = self.dev.tensor(rewards);
-
-        // Convert to tensors and normalize the states for better training
-        let states: Tensor<Rank2<BATCH, STATE_SIZE>, f32, _> =
-            self.dev.tensor(states).normalize::<Axis<1>>(0.001);
-
-        // Convert actions to tensors and get the max action for each batch
-        let actions: Tensor<Rank1<BATCH>, usize, _> = self.dev.tensor(actions.map(|a| {
-            let mut max_idx = 0;
-            let mut max_val = 0f32;
-            for (i, v) in a.iter().enumerate() {
-                if *v > max_val {
-                    max_val = *v;
-                    max_idx = i;
-                }
-            }
-            max_idx
-        }));
-
-        // Convert to tensors and normalize the states for better training
-        let next_states: Tensor<Rank2<BATCH, STATE_SIZE>, f32, _> =
-            self.dev.tensor(next_states).normalize::<Axis<1>>(0.001);
-
-        // Compute the estimated Q-value for the action
-        for _step in 0..20 {
-            let q_values = self.q_network.forward(states.trace(grads));
-
-            let action_qs = q_values.select(actions.clone());
-
-            // targ_q = R + discount * max(Q(S'))
-            // curr_q = Q(S)[A]
-            // loss = huber(curr_q, targ_q, 1)
-            let next_q_values = self.target_q_net.forward(next_states.clone());
-            let max_next_q = next_q_values.max::<Rank1<BATCH>, _>();
-            let target_q = (max_next_q * (-dones.clone() + 1.0)) * self.gamma + rewards.clone();
-
-            let loss = huber_loss(action_qs, target_q, 1.0);
-
-            grads = loss.backward();
-
-            // update weights with optimizer
-            self.sgd
-                .update(&mut self.q_network, &grads)
-                .expect("Unused params");
-            self.q_network.zero_grads(&mut grads);
-        }
-        self.target_q_net.clone_from(&self.q_network);
-    }
-
-    /// Trains this [DQNAgentTrainer] using the given [ExplorationStrategy] and
-    /// [Agent] until the [TerminationStrategy] decides to stop.
-    pub fn train(
-        &mut self,
-        agent: &mut dyn Agent<S>,
-        termination_strategy: &mut dyn TerminationStrategy<S>,
-        exploration_strategy: &dyn ExplorationStrategy<S>,
-    ) {
-        loop {
-            // Initialize batch
-            let mut states: [[f32; STATE_SIZE]; BATCH] = [[0.0; STATE_SIZE]; BATCH];
-            let mut actions: [[f32; ACTION_SIZE]; BATCH] = [[0.0; ACTION_SIZE]; BATCH];
-            let mut next_states: [[f32; STATE_SIZE]; BATCH] = [[0.0; STATE_SIZE]; BATCH];
-            let mut rewards: [f32; BATCH] = [0.0; BATCH];
-            let mut dones = [false; BATCH];
-
-            let mut s_t_next = agent.current_state();
-
-            for i in 0..BATCH {
-                let s_t = agent.current_state().clone();
-                let action = exploration_strategy.pick_action(agent);
-
-                // current action value
-                s_t_next = agent.current_state();
-                let r_t_next = s_t_next.reward();
-
-                states[i] = s_t.into();
-                actions[i] = action.into();
-                next_states[i] = (*s_t_next).clone().into();
-                rewards[i] = r_t_next as f32;
-
-                if termination_strategy.should_stop(s_t_next) {
-                    dones[i] = true;
-                    break;
-                }
-            }
-
-            // train the network
-            self.train_dqn(states, actions, next_states, rewards, dones);
-
-            // terminate if the agent is done
-            if termination_strategy.should_stop(s_t_next) {
-                break;
-            }
-        }
-    }
-
-    #[cfg(feature = "save")]
-    pub fn save(&self, path: &str) -> Result<(), SafeTensorError> {
-        Ok(self.q_network.save_safetensors(&path)?)
-    }
-
-    #[cfg(feature = "save")]
-    pub fn load(&mut self, path: &str) -> Result<(), SafeTensorError> {
-        Ok(self.q_network.load_safetensors(&path)?)
-    }
-}
-
-impl<S, const STATE_SIZE: usize, const ACTION_SIZE: usize, const INNER_SIZE: usize> Default
-    for DQNAgentTrainer<S, STATE_SIZE, ACTION_SIZE, INNER_SIZE>
-where
-    S: State + Into<[f32; STATE_SIZE]>,
-    S::A: Into<[f32; ACTION_SIZE]>,
-    S::A: From<[f32; ACTION_SIZE]>,
-{
-    fn default() -> Self {
-        Self::new(0.99, 1e-3)
-    }
-}
diff --git a/crates/game-solver/src/reinforcement/state.rs b/crates/game-solver/src/reinforcement/state.rs
deleted file mode 100644
index 767c707..0000000
--- a/crates/game-solver/src/reinforcement/state.rs
+++ /dev/null
@@ -1,27 +0,0 @@
-use std::hash::Hash;
-
-use rand::seq::SliceRandom;
-
-/// A `State` is something which has a reward, and has a certain set of actions associated with it.
-/// The type of the actions must be defined as the associated type `A`.
-pub trait State: Eq + Hash + Clone {
-    /// Action type associate with this `State`.
-    type A: Eq + Hash + Clone;
-
-    /// The reward for when an `Agent` arrives at this `State`.
-    ///
-    /// Rewards are relative to each other, and are traditionally smaller integers.
-    fn reward(&self) -> f64;
-    /// The set of actions that can be taken from this `State`, to arrive in another `State`.
-    fn actions(&self) -> Vec<Self::A>;
-    /// Selects a random action that can be taken from this `State`. The default implementation
-    /// takes a uniformly distributed random action from the defined set of actions. You may want
-    /// to improve the performance by only generating the necessary action.
-    fn random_action(&self) -> Self::A {
-        let actions = self.actions();
-        actions
-            .choose(&mut rand::thread_rng())
-            .cloned()
-            .expect("No actions available; perhaps use the SinkStates termination strategy?")
-    }
-}
diff --git a/crates/game-solver/src/reinforcement/strategy/explore.rs b/crates/game-solver/src/reinforcement/strategy/explore.rs
deleted file mode 100644
index 4f95d4c..0000000
--- a/crates/game-solver/src/reinforcement/strategy/explore.rs
+++ /dev/null
@@ -1,26 +0,0 @@
-use crate::reinforcement::{agent::Agent, state::State};
-
-/// Trait for exploration strategies. An exploration strategy decides, based on an `Agent`, which
-/// action to take next.
-pub trait ExplorationStrategy<S: State> {
-    /// Selects the next action to take for this `Agent`.
-    fn pick_action(&self, agent: &mut dyn Agent<S>) -> S::A;
-}
-
-/// The random exploration strategy.
-/// This strategy always takes a random action, as defined for the
-/// Agent by
-/// Agent::take_random_action()
-pub struct RandomExploration;
-
-impl Default for RandomExploration {
-    fn default() -> Self {
-        Self
-    }
-}
-
-impl<S: State> ExplorationStrategy<S> for RandomExploration {
-    fn pick_action(&self, agent: &mut dyn Agent<S>) -> S::A {
-        agent.pick_random_action()
-    }
-}
diff --git a/crates/game-solver/src/reinforcement/strategy/mod.rs b/crates/game-solver/src/reinforcement/strategy/mod.rs
deleted file mode 100644
index 05498d0..0000000
--- a/crates/game-solver/src/reinforcement/strategy/mod.rs
+++ /dev/null
@@ -1,2 +0,0 @@
-pub mod explore;
-pub mod terminate;
diff --git a/crates/game-solver/src/reinforcement/strategy/terminate.rs b/crates/game-solver/src/reinforcement/strategy/terminate.rs
deleted file mode 100644
index d9f8259..0000000
--- a/crates/game-solver/src/reinforcement/strategy/terminate.rs
+++ /dev/null
@@ -1,7 +0,0 @@
-use crate::reinforcement::state::State;
-
-/// A termination strategy decides when to end training.
-pub trait TerminationStrategy<S: State> {
-    /// If `should_stop` returns `true`, training will end.
-    fn should_stop(&mut self, state: &S) -> bool;
-}
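
Note on the surviving change in crates/game-solver/src/disjoint_game.rs: the new bounds mean a disjoint sum can only be built from games that opt into the `NormalImpartial` marker trait. The sketch below is a minimal, self-contained illustration of that bound pattern; `Game`, `NormalImpartial`, and `Nim` here are simplified hypothetical stand-ins, not the crate's actual API.

// Minimal sketch of the marker-trait bound pattern adopted in disjoint_game.rs.
// The trait and type names are simplified stand-ins for illustration only.

trait Game {
    type Move;
    fn moves(&self) -> Vec<Self::Move>;
}

/// Marker trait: an impartial game played under the normal play convention.
trait NormalImpartial: Game {}

/// Only games implementing the marker trait can form a disjoint sum,
/// mirroring the bounds the patch adds to `DisjointImpartialNormalGame`.
struct DisjointSum<L: Game + NormalImpartial, R: Game + NormalImpartial> {
    left: L,
    right: R,
}

/// Hypothetical single-heap Nim: a move removes 1..=n counters.
struct Nim(usize);

impl Game for Nim {
    type Move = usize;
    fn moves(&self) -> Vec<usize> {
        (1..=self.0).collect()
    }
}

// Opting in to the marker trait is exactly what the new bounds require.
impl NormalImpartial for Nim {}

fn main() {
    let sum = DisjointSum { left: Nim(3), right: Nim(5) };
    // Count the moves available across both components of the sum.
    let total = sum.left.moves().len() + sum.right.moves().len();
    println!("moves available across the disjoint sum: {total}");
}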