From 952fa3a3227655e55b9717c9c49e213190758103 Mon Sep 17 00:00:00 2001 From: "Tristan F." Date: Thu, 26 Sep 2024 14:34:03 -0700 Subject: [PATCH] feat: sprouts --- Cargo.lock | 2 + book/src/performance.md | 4 +- crates/game-solver/src/lib.rs | 6 + crates/games-cli/src/main.rs | 3 +- crates/games/Cargo.toml | 2 +- crates/games/src/lib.rs | 13 +- crates/games/src/sprouts/mod.rs | 269 +++++++++++++++++++++++++- crates/games/src/util/cli/robot.rs | 4 +- crates/games/src/util/move_natural.rs | 9 +- 9 files changed, 297 insertions(+), 15 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b8bc297..463941a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2508,6 +2508,8 @@ checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" dependencies = [ "fixedbitset", "indexmap", + "serde", + "serde_derive", ] [[package]] diff --git a/book/src/performance.md b/book/src/performance.md index e545ea6..6b4d0e8 100644 --- a/book/src/performance.md +++ b/book/src/performance.md @@ -19,7 +19,6 @@ - If you want to use xxHash without parallelization, pass it to your hashmap by using `hasher: std::hash::BuildHasherDefault`. - You can disable xxhash by removing the `xxhash` feature. - More information about why you may want to do this can be found in the [hashing](#hashing) section -- ML-based move ordering with [candle](https://github.com/huggingface/candle/) - Parallelization with [rayon](https://github.com/rayon-rs/rayon) - Note that this is under the `rayon` feature flag. - TODO: Use Lazy SMP (currently this is using naive parallelization on the `move_scores` level) @@ -41,6 +40,9 @@ You can also use `game-solver`'s [reinforcement learning](./reinforcement_learni If possible, try to "guess" the score of a move, and sort the moves by that score. +Since `game-solver` uses principal variation search, if the first move in the move ordering is great, +this solver will generally work very fast. + ### Efficient Bitboards Use efficient bitboards - you can look at the examples for inspiration, but make sure your board representation is fast, and *preferably* doesn't need allocation. diff --git a/crates/game-solver/src/lib.rs b/crates/game-solver/src/lib.rs index a3af796..ed04e49 100644 --- a/crates/game-solver/src/lib.rs +++ b/crates/game-solver/src/lib.rs @@ -252,6 +252,12 @@ pub fn solve + Eq + Hash>( // we're trying to guess the score of the board via null windows while alpha < beta { + if let Some(token) = cancellation_token { + if token.load(Ordering::Relaxed) { + return Err(GameSolveError::CancellationTokenError); + } + } + let med = alpha + (beta - alpha) / 2; // do a [null window search](https://www.chessprogramming.org/Null_Window) diff --git a/crates/games-cli/src/main.rs b/crates/games-cli/src/main.rs index 8a87c24..5c36810 100644 --- a/crates/games-cli/src/main.rs +++ b/crates/games-cli/src/main.rs @@ -2,7 +2,7 @@ use anyhow::Result; use clap::Parser; use games::{ chomp::Chomp, domineering::Domineering, nim::Nim, order_and_chaos::OrderAndChaos, - reversi::Reversi, tic_tac_toe::TicTacToe, util::cli::play, Games, + reversi::Reversi, sprouts::Sprouts, tic_tac_toe::TicTacToe, util::cli::play, Games, }; /// `game-solver` is a solving utility that helps analyze various combinatorial games. @@ -27,6 +27,7 @@ fn main() -> Result<()> { Games::Nim(args) => play::(args.try_into().unwrap(), cli.plain), Games::Domineering(args) => play::>(args.try_into().unwrap(), cli.plain), Games::Chomp(args) => play::(args.try_into().unwrap(), cli.plain), + Games::Sprouts(args) => play::(args.try_into().unwrap(), cli.plain), }; Ok(()) diff --git a/crates/games/Cargo.toml b/crates/games/Cargo.toml index 4c3a167..467cea4 100644 --- a/crates/games/Cargo.toml +++ b/crates/games/Cargo.toml @@ -18,7 +18,7 @@ once_cell = "1.19.0" egui = { version = "0.28", optional = true } egui_commonmark = { version = "0.17.0", optional = true, features = ["macros"] } thiserror = "1.0.63" -petgraph = "0.6.5" +petgraph = { version = "0.6.5", features = ["serde-1"] } castaway = "0.2.3" ratatui = "0.28.1" owo-colors = "4.1.0" diff --git a/crates/games/src/lib.rs b/crates/games/src/lib.rs index 65467c3..5514f62 100644 --- a/crates/games/src/lib.rs +++ b/crates/games/src/lib.rs @@ -11,6 +11,7 @@ pub mod tic_tac_toe; use crate::{ chomp::ChompArgs, domineering::DomineeringArgs, nim::NimArgs, order_and_chaos::OrderAndChaosArgs, reversi::ReversiArgs, tic_tac_toe::TicTacToeArgs, + sprouts::SproutsArgs }; use clap::Subcommand; use once_cell::sync::Lazy; @@ -24,9 +25,10 @@ pub enum Games { Nim(NimArgs), Domineering(DomineeringArgs), Chomp(ChompArgs), + Sprouts(SproutsArgs) } -pub static DEFAULT_GAMES: Lazy<[Games; 6]> = Lazy::new(|| { +pub static DEFAULT_GAMES: Lazy<[Games; 7]> = Lazy::new(|| { [ Games::Reversi(Default::default()), Games::TicTacToe(Default::default()), @@ -34,6 +36,7 @@ pub static DEFAULT_GAMES: Lazy<[Games; 6]> = Lazy::new(|| { Games::Nim(Default::default()), Games::Domineering(Default::default()), Games::Chomp(Default::default()), + Games::Sprouts(Default::default()) ] }); @@ -46,6 +49,7 @@ impl Games { Self::Nim(_) => "Nim".to_string(), Self::Domineering(_) => "Domineering".to_string(), Self::Chomp(_) => "Chomp".to_string(), + Self::Sprouts(_) => "Sprouts".to_string() } } @@ -57,6 +61,7 @@ impl Games { Self::Nim(_) => include_str!("./nim/README.md"), Self::Domineering(_) => include_str!("./domineering/README.md"), Self::Chomp(_) => include_str!("./chomp/README.md"), + Self::Sprouts(_) => include_str!("./sprouts/README.md") } } @@ -100,6 +105,12 @@ impl Games { &mut cache, "crates/games/src/chomp/README.md" ), + Self::Sprouts(_) => egui_commonmark::commonmark_str!( + "sprouts", + ui, + &mut cache, + "crates/games/src/sprouts/README.md" + ), }; } } diff --git a/crates/games/src/sprouts/mod.rs b/crates/games/src/sprouts/mod.rs index 374255b..85abb32 100644 --- a/crates/games/src/sprouts/mod.rs +++ b/crates/games/src/sprouts/mod.rs @@ -1,11 +1,270 @@ #![doc = include_str!("./README.md")] -use game_solver::game::Game; -use petgraph::matrix_graph::MatrixGraph; +use std::{fmt::{Debug, Display}, hash::Hash, str::FromStr}; + +use anyhow::Error; +use clap::Args; +use game_solver::{game::{Game, StateType}, player::ImpartialPlayer}; +use itertools::Itertools; +use petgraph::{matrix_graph::{MatrixGraph, NodeIndex}, visit::{IntoEdgeReferences, IntoNodeIdentifiers}, Undirected}; +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +use crate::util::cli::move_failable; + +/// We aren't dealing with large sprout counts for now. +pub type SproutsIx = u8; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] +pub struct SproutsMove { + from: NodeIndex, + to: NodeIndex +} + +impl Display for SproutsMove { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "({} {})", self.from.index(), self.to.index()) + } +} + +type SproutsGraph = MatrixGraph<(), (), Undirected, Option<()>, SproutsIx>; #[derive(Clone)] -pub struct Sprouts(MatrixGraph<(), ()>); +pub struct Sprouts(SproutsGraph); + +// SproutsGraph, given that its vertices and edges are unlabelled, +// doesn't implement equality as that requires isomorphism checks. +// since we don't want these operations for reordering to be expensive, +// we simply check for equality as is. + +impl Hash for Sprouts { + fn hash(&self, state: &mut H) { + self.0.node_count().hash(state); + for edge in self.0.edge_references() { + edge.hash(state); + } + } +} + +impl PartialEq for Sprouts { + fn eq(&self, other: &Self) -> bool { + self.0.node_count() == other.0.node_count() + && self.0.edge_references().collect::>() == other.0.edge_references().collect::>() + } + + fn ne(&self, other: &Self) -> bool { + !self.eq(other) + } +} + +impl Eq for Sprouts {} + +impl Sprouts { + pub fn new(node_count: SproutsIx) -> Self { + let mut graph = SproutsGraph::default(); + + for _ in 0..node_count { + graph.add_node(()); + } + + Self(graph) + } +} + +#[derive(Error, Debug, Clone)] +pub enum SproutsMoveError { + #[error("chosen index {0} from move {1:?} is out of bounds.")] + MoveOutOfBounds(SproutsIx, SproutsMove), + #[error("chosen index {0} from move {1:?} references a dead sprout.")] + DeadSprout(SproutsIx, SproutsMove), + #[error("a move for {0:?} has already been made")] + SproutsConnected(SproutsMove) +} + +const MAX_SPROUTS: usize = 3; + +impl Game for Sprouts { + type Move = SproutsMove; + type Iter<'a> = std::vec::IntoIter; + + type Player = ImpartialPlayer; + type MoveError = SproutsMoveError; + + const STATE_TYPE: Option = Some(StateType::Normal); + + fn max_moves(&self) -> Option { + // TODO: i actually want to find what the proper paper is, but + // https://en.wikipedia.org/wiki/Sprouts_(game)#Maximum_number_of_moves + // is where this is from. + // TODO: use MAX_SPROUTS? + Some(3 * self.0.node_count() - 1) + } + + fn move_count(&self) -> usize { + self.0.edge_count() + } + + fn make_move(&mut self, m: &Self::Move) -> Result<(), Self::MoveError> { + // There already exists an edge here! + if self.0.has_edge(m.from, m.to) { + return Err(SproutsMoveError::SproutsConnected(m.clone())); + } + + // move index is out of bounds + { + if !self.0.node_identifiers().contains(&m.from) { + return Err(SproutsMoveError::MoveOutOfBounds( + m.from.index().try_into().unwrap(), + m.clone()) + ); + } + + if !self.0.node_identifiers().contains(&m.to) { + return Err(SproutsMoveError::MoveOutOfBounds( + m.to.index().try_into().unwrap(), + m.clone()) + ); + } + } + + // sprouts to use are dead + { + if self.0.edges(m.from).count() >= MAX_SPROUTS { + return Err(SproutsMoveError::DeadSprout( + m.from.index().try_into().unwrap(), + m.clone() + )); + } + + if self.0.edges(m.to).count() >= MAX_SPROUTS { + return Err(SproutsMoveError::DeadSprout( + m.to.index().try_into().unwrap(), + m.clone() + )); + } + } + + self.0.add_edge(m.from, m.to, ()); + + Ok(()) + } + + fn player(&self) -> Self::Player { + ImpartialPlayer::Next + } + + fn possible_moves(&self) -> Self::Iter<'_> { + let mut sprouts_moves = vec![]; + + for id in self.0.node_identifiers() { + let edge_count = self.0.edges(id).count(); + + // TODO: use MAX_SPROUTS for all values + match edge_count { + 0 | 1 => { + if !self.0.has_edge(id, id) { + sprouts_moves.push(SproutsMove { from: id, to: id }); + } + for sub_id in self.0.node_identifiers() { + if id >= sub_id { continue; } + if self.0.edges(sub_id).count() >= MAX_SPROUTS { continue; } + if self.0.has_edge(id, sub_id) { continue; } + sprouts_moves.push(SproutsMove { from: id, to: sub_id }) + } + }, + 2 => { + for sub_id in self.0.node_identifiers() { + if id >= sub_id { continue; } + if self.0.edges(sub_id).count() >= MAX_SPROUTS { continue; } + if self.0.has_edge(id, sub_id) { continue; } + sprouts_moves.push(SproutsMove { from: id, to: sub_id }) + } + }, + MAX_SPROUTS => (), + _ => panic!("No node should have more than three edges") + } + } + + sprouts_moves.into_iter() + } + + fn state(&self) -> game_solver::game::GameState { + Self::STATE_TYPE.unwrap().state(self) + } +} + +impl Debug for Sprouts { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let references = self.0.edge_references().collect::>(); + + writeln!(f, "graph of vertices count {}", references.len())?; + + if references.is_empty() { + return Ok(()); + } + + for (i, j, _) in references { + write!(f, "{}-{} ", i.index(), j.index())?; + } + writeln!(f)?; + + Ok(()) + } +} + +impl Display for Sprouts { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + ::fmt(&self, f) + } +} + +/// Analyzes Sprouts. +/// +#[doc = include_str!("./README.md")] +#[derive(Args, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Clone)] +pub struct SproutsArgs { + /// The amount of sprouts (nodes) + /// to start off with. + starting_sprouts: SproutsIx, + /// Sprouts moves, ordered as i1-j1 i2-j2 ... + #[arg(value_parser = clap::value_parser!(SproutsMove))] + moves: Vec +} + +impl Default for SproutsArgs { + fn default() -> Self { + Self { + starting_sprouts: 6, + moves: vec![] + } + } +} + +impl TryFrom for Sprouts { + type Error = Error; + + fn try_from(args: SproutsArgs) -> Result { + let mut game = Sprouts::new(args.starting_sprouts); + + for sprouts_move in args.moves { + move_failable(&mut game, &sprouts_move)?; + } + + Ok(game) + } +} + +impl FromStr for SproutsMove { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + let components = s.split("-").collect::>(); -// impl Game for Sprouts { + assert_eq!(components.len(), 2, "a move shouldn't connect more than two sprouts"); -// } + Ok(SproutsMove { + from: str::parse::(components[0])?.into(), + to: str::parse::(components[1])?.into() + }) + } +} diff --git a/crates/games/src/util/cli/robot.rs b/crates/games/src/util/cli/robot.rs index 768fcd6..dff0e28 100644 --- a/crates/games/src/util/cli/robot.rs +++ b/crates/games/src/util/cli/robot.rs @@ -1,5 +1,5 @@ use game_solver::{ - game::{score_to_outcome, Game, GameScoreOutcome}, + game::Game, par_move_scores, player::{ImpartialPlayer, TwoPlayer}, }; @@ -9,7 +9,7 @@ use std::{ hash::Hash, }; -use crate::util::{cli::report::scores::show_scores, move_score::normalize_move_scores}; +use crate::util::cli::report::scores::show_scores; pub fn robotic_output< T: Game diff --git a/crates/games/src/util/move_natural.rs b/crates/games/src/util/move_natural.rs index d43a1f7..31b8835 100644 --- a/crates/games/src/util/move_natural.rs +++ b/crates/games/src/util/move_natural.rs @@ -1,5 +1,6 @@ use std::{fmt::Display, iter, str::FromStr}; +use anyhow::{anyhow, Error}; use itertools::Itertools; use serde::{Deserialize, Serialize}; use serde_big_array::BigArray; @@ -8,7 +9,7 @@ use serde_big_array::BigArray; pub struct NaturalMove(#[serde(with = "BigArray")] pub [usize; LENGTH]); impl FromStr for NaturalMove { - type Err = String; + type Err = Error; fn from_str(s: &str) -> Result { assert!(LENGTH > 0, "Length must be greater than 0"); @@ -17,7 +18,7 @@ impl FromStr for NaturalMove { let numbers = s.split('-').collect::>(); if numbers.len() != LENGTH { - return Err(format!( + return Err(anyhow!( "Must be {} numbers separated by a hyphen ({})", LENGTH, iter::repeat("x").take(LENGTH).join("-") @@ -32,7 +33,7 @@ impl FromStr for NaturalMove { if let Some((position, _)) = numbers.iter().find_position(|x| x.is_err()) { let ordinal = ordinal::Ordinal(position + 1).to_string(); - return Err(format!("The {} number is not a number.", ordinal)); + return Err(anyhow!("The {} number is not a number.", ordinal)); } numbers @@ -40,7 +41,7 @@ impl FromStr for NaturalMove { .map(|x| x.clone().unwrap()) .collect::>() .try_into() - .map_err(|_| "Could not convert Vec to fixed array; this is a bug.".to_string()) + .map_err(|_| anyhow!("Could not convert Vec to fixed array; this is a bug.")) .map(NaturalMove) } }