Skip to content

Add seedable generate functions #26

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: stable
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,19 @@ maintenance = { status = "passively-maintained" }
travis-ci = { repository = "aatxe/markov" }

[features]
default = ["graph", "markgen", "yaml"]
default = ["graph", "markgen", "yaml", "seedable"]
graph = ["petgraph", "itertools"]
markgen = ["getopts"]
yaml = ["serde_yaml"]
seedable = ["linked-hash-map", "rand_chacha"]

[dependencies]
getopts = { version = "0.2.21", optional = true }
itertools = { version = "0.10.1", optional = true }
petgraph = { version = "0.6.0", optional = true }
linked-hash-map = { version = "0.5.6", features = ["serde_impl"], optional = true }
rand = "0.8.4"
rand_chacha = { version = "0.3.1" , optional = true}
serde = "1.0.130"
serde_derive = "1.0.130"
serde_yaml = { version = "0.8.20", optional = true }
120 changes: 115 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,25 @@ extern crate serde_derive;
#[cfg(feature = "yaml")]
extern crate serde_yaml;

#[cfg(feature = "seedable")]
extern crate linked_hash_map;

#[cfg(feature = "seedable")]
extern crate rand_chacha;

use std::borrow::ToOwned;

#[cfg(feature = "seedable")]
use linked_hash_map::Entry::{Occupied, Vacant};
#[cfg(feature = "seedable")]
use linked_hash_map::LinkedHashMap as HashMap;

#[cfg(feature = "seedable")]
use rand_chacha::{rand_core::SeedableRng, ChaCha12Rng};

#[cfg(not(feature = "seedable"))]
use std::collections::hash_map::Entry::{Occupied, Vacant};
#[cfg(not(feature = "seedable"))]
use std::collections::HashMap;
use std::fs::File;
use std::hash::Hash;
Expand Down Expand Up @@ -138,10 +155,31 @@ where
/// length of the generated collection, and `n` is the number of possible states from a given
/// state.
pub fn generate(&self) -> Vec<T> {
    // Default entry point: use the generator returned by `thread_rng()`.
    let mut rng = thread_rng();
    self.generate_base(&mut rng)
}

/// Generates a collection of tokens from the chain using the caller-supplied random number
/// generator `rng`. This operation is `O(mn)` where `m` is the length of the generated
/// collection, and `n` is the number of possible states from a given state.
///
/// Only compiled when the `seedable` feature is enabled.
#[cfg(feature = "seedable")]
pub fn generate_with_rng<R>(&self, rng: &mut R) -> Vec<T>
where
    R: Rng,
{
    self.generate_base(rng)
}

/// Generates a collection of tokens from the chain, with randomness supplied by a
/// `ChaCha12Rng` seeded from `seed`. This operation is `O(mn)` where `m` is the length of the
/// generated collection, and `n` is the number of possible states from a given state.
///
/// Only compiled when the `seedable` feature is enabled.
#[cfg(feature = "seedable")]
pub fn generate_with_seed(&self, seed: u64) -> Vec<T> {
    // A fresh generator is built on every call, so each call starts from the seed itself.
    self.generate_base(&mut ChaCha12Rng::seed_from_u64(seed))
}

fn generate_base<R: Rng>(&self, rng: &mut R) -> Vec<T> {
let mut ret = Vec::new();
let mut curs = vec![None; self.order];
loop {
let next = self.map[&curs].next();
let next = self.map[&curs].next(rng);
curs = curs[1..self.order].to_vec();
curs.push(next.clone());
if let Some(next) = next {
Expand All @@ -159,14 +197,37 @@ where
/// of possible states from a given state. This returns an empty vector if the token is not
/// found.
pub fn generate_from_token(&self, token: T) -> Vec<T> {
    // Default entry point: use the generator returned by `thread_rng()`.
    let mut rng = thread_rng();
    self.generate_from_token_base(token, &mut rng)
}

/// Generates a collection of tokens from the chain, starting with the given token and using
/// the caller-supplied random number generator `rng`. This operation is `O(mn)` where `m` is
/// the length of the generated collection, and `n` is the number of possible states from a
/// given state. This returns an empty vector if the token is not found.
///
/// Only compiled when the `seedable` feature is enabled.
#[cfg(feature = "seedable")]
pub fn generate_from_token_with_rng<R>(&self, token: T, rng: &mut R) -> Vec<T>
where
    R: Rng,
{
    self.generate_from_token_base(token, rng)
}

/// Generates a collection of tokens from the chain, starting with the given token, with
/// randomness supplied by a `ChaCha12Rng` seeded from `seed`. This operation is `O(mn)` where
/// `m` is the length of the generated collection, and `n` is the number of possible states
/// from a given state. This returns an empty vector if the token is not found.
///
/// Only compiled when the `seedable` feature is enabled.
#[cfg(feature = "seedable")]
pub fn generate_from_token_with_seed(&self, token: T, seed: u64) -> Vec<T> {
    // A fresh generator is built on every call, so each call starts from the seed itself.
    self.generate_from_token_base(token, &mut ChaCha12Rng::seed_from_u64(seed))
}

fn generate_from_token_base<R: Rng>(&self, token: T, rng: &mut R) -> Vec<T> {
let mut curs = vec![None; self.order - 1];
curs.push(Some(token.clone()));
if !self.map.contains_key(&curs) {
return Vec::new();
}
let mut ret = vec![token];
loop {
let next = self.map[&curs].next();
let next = self.map[&curs].next(rng);
curs = curs[1..self.order].to_vec();
curs.push(next.clone());
if let Some(next) = next {
Expand Down Expand Up @@ -398,7 +459,7 @@ trait States<T: PartialEq> {
/// Adds a state to this states collection.
fn add(&mut self, token: Token<T>, count: usize);
/// Gets the next state from this collection of states.
fn next(&self) -> Token<T>;
fn next<R: Rng>(&self, rng: &mut R) -> Token<T>;
}

impl<T> States<T> for HashMap<Token<T>, usize>
Expand All @@ -414,13 +475,17 @@ where
}
}

fn next(&self) -> Token<T> {
fn next<R>(&self, rng: &mut R) -> Token<T>
where
R: Rng,
{
let mut sum = 0;
for &value in self.values() {
sum += value;
}
let mut rng = thread_rng();

let cap = rng.gen_range(0..sum);

sum = 0;
for (key, &value) in self.iter() {
sum += value;
Expand All @@ -434,6 +499,9 @@ where

#[cfg(test)]
mod test {

use rand_chacha::{rand_core::SeedableRng, ChaCha12Rng};

use super::Chain;

#[test]
Expand Down Expand Up @@ -464,6 +532,27 @@ mod test {
assert!([vec![3, 5, 10], vec![3, 5, 12], vec![5, 10], vec![5, 12]].contains(&v));
}

// Gated on `seedable`: `Chain::generate_with_seed` is only compiled with that feature, so
// without the gate this test fails to build when the feature is disabled.
#[cfg(feature = "seedable")]
#[test]
fn generate_with_seed() {
    let mut chain = Chain::new();
    chain.feed(vec![3u8, 5, 10]).feed(vec![5, 12]);
    // Expected outputs are pinned to these specific seeds; `assert_eq!` prints both sides
    // on failure, which `assert!(a == b)` does not.
    assert_eq!(chain.generate_with_seed(3), vec![3, 5, 10]);
    assert_eq!(chain.generate_with_seed(1), vec![5, 10]);
}

// Gated on `seedable`: `Chain::generate_with_rng` and the `rand_chacha` generator used here
// are only available with that feature.
#[cfg(feature = "seedable")]
#[test]
fn generate_with_rng() {
    let mut rng = ChaCha12Rng::seed_from_u64(3);
    let mut chain = Chain::new();
    chain.feed(vec![3u8, 5, 10]).feed(vec![5, 12]);
    // The same generator is passed to both calls, so the second call continues from the
    // state the first one left behind rather than restarting from the seed.
    assert_eq!(chain.generate_with_rng(&mut rng), vec![3, 5, 10]);
    assert_eq!(chain.generate_with_rng(&mut rng), vec![3, 5, 12]);
}

#[test]
fn generate_for_higher_order() {
let mut chain = Chain::of_order(2);
Expand All @@ -486,6 +575,27 @@ mod test {
assert!([vec![5, 10], vec![5, 12]].contains(&v));
}

// Gated on `seedable`: `Chain::generate_from_token_with_seed` is only compiled with that
// feature, so without the gate this test fails to build when the feature is disabled.
#[cfg(feature = "seedable")]
#[test]
fn generate_from_token_with_seed() {
    let mut chain = Chain::new();
    chain.feed(vec![3u8, 5, 10, 13]).feed(vec![5, 12, 10]);
    // Expected outputs are pinned to these specific seeds; `assert_eq!` prints both sides
    // on failure, which `assert!(a == b)` does not.
    assert_eq!(chain.generate_from_token_with_seed(5, 3), vec![5, 10, 13]);
    assert_eq!(chain.generate_from_token_with_seed(5, 1), vec![5, 12, 10, 13]);
}

// Gated on `seedable`: `Chain::generate_from_token_with_rng` and the `rand_chacha` generator
// used here are only available with that feature.
#[cfg(feature = "seedable")]
#[test]
fn generate_from_token_with_rng() {
    let mut rng = ChaCha12Rng::seed_from_u64(3);
    let mut chain = Chain::new();
    chain.feed(vec![3u8, 5, 10, 13]).feed(vec![5, 12, 10]);
    // The same generator is passed to both calls, so the second call continues from the
    // state the first one left behind rather than restarting from the seed.
    assert_eq!(chain.generate_from_token_with_rng(5, &mut rng), vec![5, 10, 13]);
    assert_eq!(chain.generate_from_token_with_rng(5, &mut rng), vec![5, 10]);
}

#[test]
fn generate_from_unfound_token() {
let mut chain = Chain::new();
Expand Down