Revising changes for test state

This commit is contained in:
vandomej 2021-10-06 09:36:46 -07:00
parent 2fdf4c2545
commit 789f8feef3
4 changed files with 29 additions and 405 deletions

View file

@ -23,8 +23,7 @@ fn main() -> anyhow::Result<()> {
let file_path = matches.value_of(gemla::constants::args::FILE).unwrap(); let file_path = matches.value_of(gemla::constants::args::FILE).unwrap();
let mut gemla = Gemla::<TestState>::new(&PathBuf::from(file_path), true)?; let mut gemla = Gemla::<TestState>::new(&PathBuf::from(file_path), true)?;
gemla.simulate(1)?; gemla.simulate(3)?;
gemla.simulate(1)?;
gemla.simulate(1)?; gemla.simulate(1)?;
Ok(()) Ok(())

View file

@ -1,7 +1,7 @@
use gemla::bracket::genetic_node::GeneticNode; use gemla::bracket::genetic_node::GeneticNode;
use gemla::error; use gemla::error;
use rand::prelude::*; use rand::prelude::*;
use rand::{random, thread_rng}; use rand::thread_rng;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::convert::TryInto; use std::convert::TryInto;
@ -10,15 +10,15 @@ const POPULATION_REDUCTION_SIZE: u64 = 3;
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct TestState { pub struct TestState {
pub population: Vec<f64>, pub population: Vec<i64>,
} }
impl Default for TestState { impl Default for TestState {
fn default() -> Self { fn default() -> Self {
let mut population: Vec<f64> = vec![]; let mut population: Vec<i64> = vec![];
for _ in 0..POPULATION_SIZE { for _ in 0..POPULATION_SIZE {
population.push(random::<u64>() as f64) population.push(thread_rng().gen_range(0..10000))
} }
TestState { population } TestState { population }
@ -27,10 +27,10 @@ impl Default for TestState {
impl GeneticNode for TestState { impl GeneticNode for TestState {
fn initialize() -> Result<Box<Self>, error::Error> { fn initialize() -> Result<Box<Self>, error::Error> {
let mut population: Vec<f64> = vec![]; let mut population: Vec<i64> = vec![];
for _ in 0..POPULATION_SIZE { for _ in 0..POPULATION_SIZE {
population.push(random::<u64>() as f64) population.push(thread_rng().gen_range(0..10000))
} }
Ok(Box::new(TestState { population })) Ok(Box::new(TestState { population }))
@ -44,26 +44,23 @@ impl GeneticNode for TestState {
.population .population
.clone() .clone()
.iter() .iter()
.map(|p| p + rng.gen_range(-10.0..10.0)) .map(|p| p + rng.gen_range(-10..10))
.collect() .collect()
} }
Ok(()) Ok(())
} }
fn calculate_scores_and_trim(&mut self) -> Result<(), error::Error> { fn mutate(&mut self) -> Result<(), error::Error> {
let mut rng = thread_rng();
let mut v = self.population.clone(); let mut v = self.population.clone();
v.sort_by(|a, b| a.partial_cmp(b).unwrap()); v.sort();
v.reverse(); v.reverse();
self.population = v[0..(POPULATION_REDUCTION_SIZE as usize)].to_vec(); self.population = v[0..(POPULATION_REDUCTION_SIZE as usize)].to_vec();
Ok(())
}
fn mutate(&mut self) -> Result<(), error::Error> {
let mut rng = thread_rng();
loop { loop {
if self.population.len() >= POPULATION_SIZE.try_into().unwrap() { if self.population.len() >= POPULATION_SIZE.try_into().unwrap() {
break; break;
@ -83,7 +80,7 @@ impl GeneticNode for TestState {
let mut new_individual = self.population.clone()[new_individual_index]; let mut new_individual = self.population.clone()[new_individual_index];
let cross_breed = self.population.clone()[cross_breed_index]; let cross_breed = self.population.clone()[cross_breed_index];
new_individual += cross_breed + rng.gen_range(-10.0..10.0); new_individual += cross_breed + rng.gen_range(-10..10);
self.population.push(new_individual); self.population.push(new_individual);
} }
@ -123,7 +120,7 @@ mod tests {
#[test] #[test]
fn test_simulate() { fn test_simulate() {
let mut state = TestState { let mut state = TestState {
population: vec![1.0, 1.0, 2.0, 3.0], population: vec![1, 1, 2, 3],
}; };
let original_population = state.population.clone(); let original_population = state.population.clone();
@ -135,33 +132,19 @@ mod tests {
assert!(original_population assert!(original_population
.iter() .iter()
.zip(state.population.iter()) .zip(state.population.iter())
.all(|(&a, &b)| b >= a - 10.0 && b <= a + 10.0)); .all(|(&a, &b)| b >= a - 10 && b <= a + 10));
state.simulate(2).unwrap(); state.simulate(2).unwrap();
assert!(original_population assert!(original_population
.iter() .iter()
.zip(state.population.iter()) .zip(state.population.iter())
.all(|(&a, &b)| b >= a - 30.0 && b <= a + 30.0)) .all(|(&a, &b)| b >= a - 30 && b <= a + 30))
}
#[test]
fn test_calculate_scores_and_trim() {
let mut state = TestState {
population: vec![4.0, 1.0, 1.0, 3.0, 2.0],
};
state.calculate_scores_and_trim().unwrap();
assert_eq!(state.population.len(), POPULATION_REDUCTION_SIZE as usize);
assert!(state.population.iter().any(|&x| x == 4.0));
assert!(state.population.iter().any(|&x| x == 3.0));
assert!(state.population.iter().any(|&x| x == 2.0));
} }
#[test] #[test]
fn test_mutate() { fn test_mutate() {
let mut state = TestState { let mut state = TestState {
population: vec![4.0, 3.0, 3.0], population: vec![4, 3, 3],
}; };
state.mutate().unwrap(); state.mutate().unwrap();
@ -172,18 +155,18 @@ mod tests {
#[test] #[test]
fn test_merge() { fn test_merge() {
let state1 = TestState { let state1 = TestState {
population: vec![1.0, 2.0, 4.0, 5.0], population: vec![1, 2, 4, 5],
}; };
let state2 = TestState { let state2 = TestState {
population: vec![0.0, 1.0, 3.0, 7.0], population: vec![0, 1, 3, 7],
}; };
let merged_state = TestState::merge(&state1, &state2).unwrap(); let merged_state = TestState::merge(&state1, &state2).unwrap();
assert_eq!(merged_state.population.len(), POPULATION_SIZE as usize); assert_eq!(merged_state.population.len(), POPULATION_SIZE as usize);
assert!(merged_state.population.iter().any(|&x| x == 7.0)); assert!(merged_state.population.iter().any(|&x| x == 7));
assert!(merged_state.population.iter().any(|&x| x == 5.0)); assert!(merged_state.population.iter().any(|&x| x == 5));
assert!(merged_state.population.iter().any(|&x| x == 4.0)); assert!(merged_state.population.iter().any(|&x| x == 4));
} }
} }

View file

@ -18,8 +18,6 @@ pub enum GeneticState {
Initialize, Initialize,
/// The node is currently simulating a round against target data to determine the fitness of the population /// The node is currently simulating a round against target data to determine the fitness of the population
Simulate, Simulate,
/// The node is currently selecting members of the population that scored well and reducing the total population size
Score,
/// The node is currently mutating members of it's population and breeding new members /// The node is currently mutating members of it's population and breeding new members
Mutate, Mutate,
/// The node has finished processing for a given number of iterations /// The node has finished processing for a given number of iterations
@ -33,322 +31,20 @@ pub trait GeneticNode {
/// Initializes a new instance of a [`GeneticState`]. /// Initializes a new instance of a [`GeneticState`].
/// ///
/// # Examples /// # Examples
/// /// TODO
/// ```
/// # use gemla::bracket::genetic_node::GeneticNode;
/// # use gemla::error::Error;
/// #
/// struct Node {
/// pub fit_score: f64,
/// }
///
/// impl GeneticNode for Node {
/// fn initialize() -> Result<Box<Node>, Error> {
/// Ok(Box::new(Node {fit_score: 0.0}))
/// }
///
/// //...
/// #
/// # fn simulate(&mut self, iterations: u64) -> Result<(), Error> {
/// # Ok(())
/// # }
/// #
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
/// # Ok(())
/// # }
/// #
/// # fn mutate(&mut self) -> Result<(), Error> {
/// # Ok(())
/// # }
/// #
/// # fn merge(left: &Node, right: &Node) -> Result<Box<Node>, Error> {
/// # Ok(Box::new(Node {fit_score: 0.0}))
/// # }
/// }
///
/// # fn main() -> Result<(), Error> {
/// let node = Node::initialize()?;
/// assert_eq!(node.fit_score, 0.0);
/// # Ok(())
/// # }
/// ```
fn initialize() -> Result<Box<Self>, Error>; fn initialize() -> Result<Box<Self>, Error>;
/// Runs a simulation on the state object for the given number of `iterations` in order to guage it's fitness. /// Runs a simulation on the state object for the given number of `iterations` in order to guage it's fitness.
/// This will be called for every node in a bracket before evaluating it's fitness against other nodes. /// This will be called for every node in a bracket before evaluating it's fitness against other nodes.
/// ///
/// # Examples /// # Examples
/// /// TODO
/// ```
/// # use gemla::bracket::genetic_node::GeneticNode;
/// # use gemla::error::Error;
/// #
/// struct Model {
/// pub fit_score: f64,
/// //...
/// }
///
/// struct Node {
/// pub models: Vec<Model>,
/// //...
/// }
///
/// impl Model {
/// fn fit(&mut self, epochs: u64) -> Result<(), Error> {
/// //...
/// # self.fit_score += epochs as f64;
/// # Ok(())
/// }
/// }
///
/// # impl Node {
/// # fn get_fit_score(&self) -> f64 {
/// # self.models
/// # .iter()
/// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap())
/// # .unwrap()
/// # .fit_score
/// # }
/// # }
/// #
/// impl GeneticNode for Node {
/// # fn initialize() -> Result<Box<Node>, Error> {
/// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]}))
/// # }
/// #
/// //...
///
/// fn simulate(&mut self, iterations: u64) -> Result<(), Error> {
/// for m in self.models.iter_mut()
/// {
/// m.fit(iterations)?;
/// }
/// Ok(())
/// }
///
/// //...
/// #
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
/// # Ok(())
/// # }
/// #
/// # fn mutate(&mut self) -> Result<(), Error> {
/// # Ok(())
/// # }
/// #
/// # fn merge(left: &Node, right: &Node) -> Result<Box<Node>, Error> {
/// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]}))
/// # }
/// }
///
/// # fn main() -> Result<(), Error> {
/// let mut node = Node::initialize()?;
/// node.simulate(5)?;
/// assert_eq!(node.get_fit_score(), 5.0);
/// # Ok(())
/// # }
/// ```
fn simulate(&mut self, iterations: u64) -> Result<(), Error>; fn simulate(&mut self, iterations: u64) -> Result<(), Error>;
/// Used when scoring the nodes after simulating and should remove underperforming children.
///
/// # Examples
/// ```
/// # use gemla::bracket::genetic_node::GeneticNode;
/// # use gemla::error::Error;
/// #
/// struct Model {
/// pub fit_score: f64,
/// //...
/// }
///
/// struct Node {
/// pub models: Vec<Model>,
/// population_size: i64,
/// //...
/// }
///
/// # impl Model {
/// # fn fit(&mut self, epochs: u64) -> Result<(), Error> {
/// # //...
/// # self.fit_score += epochs as f64;
/// # Ok(())
/// # }
/// # }
/// #
/// #
/// # impl Node {
/// # fn get_fit_score(&self) -> f64 {
/// # self.models
/// # .iter()
/// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap())
/// # .unwrap()
/// # .fit_score
/// # }
/// # }
/// #
///
/// impl GeneticNode for Node {
/// # fn initialize() -> Result<Box<Node>, Error> {
/// # Ok(Box::new(Node {
/// # models: vec![
/// # Model { fit_score: 0.0 },
/// # Model { fit_score: 1.0 },
/// # Model { fit_score: 2.0 },
/// # Model { fit_score: 3.0 },
/// # Model { fit_score: 4.0 },
/// # ],
/// # population_size: 5,
/// # }))
/// # }
/// #
/// # //...
/// #
/// # fn simulate(&mut self, iterations: u64) -> Result<(), Error> {
/// # for m in self.models.iter_mut() {
/// # m.fit(iterations)?;
/// # }
/// # Ok(())
/// # }
/// #
/// //...
///
/// fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
/// self.models.sort_by(|a, b| a.fit_score.partial_cmp(&b.fit_score).unwrap().reverse());
/// self.models.truncate(3);
/// Ok(())
/// }
///
/// //...
/// #
/// # fn mutate(&mut self) -> Result<(), Error> {
/// # Ok(())
/// # }
/// #
/// # fn merge(left: &Node, right: &Node) -> Result<Box<Node>, Error> {
/// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}], population_size: 1}))
/// # }
/// }
///
/// # fn main() -> Result<(), Error> {
/// let mut node = Node::initialize()?;
/// assert_eq!(node.models.len(), 5);
///
/// node.simulate(5)?;
/// node.calculate_scores_and_trim()?;
/// assert_eq!(node.models.len(), 3);
///
/// # assert_eq!(node.get_fit_score(), 9.0);
/// # Ok(())
/// # }
/// ```
fn calculate_scores_and_trim(&mut self) -> Result<(), Error>;
/// Mutates members in a population and/or crossbreeds them to produce new offspring. /// Mutates members in a population and/or crossbreeds them to produce new offspring.
/// ///
/// # Examples /// # Examples
/// ``` /// TODO
/// # use gemla::bracket::genetic_node::GeneticNode;
/// # use gemla::error::Error;
/// # use std::convert::TryInto;
/// #
/// struct Model {
/// pub fit_score: f64,
/// //...
/// }
///
/// struct Node {
/// pub models: Vec<Model>,
/// population_size: i64,
/// //...
/// }
/// #
/// # impl Node {
/// # fn get_fit_score(&self) -> f64 {
/// # self.models
/// # .iter()
/// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap())
/// # .unwrap()
/// # .fit_score
/// # }
/// # }
///
/// # impl Model {
/// # fn fit(&mut self, epochs: u64) -> Result<(), Error> {
/// # //...
/// # self.fit_score += epochs as f64;
/// # Ok(())
/// # }
/// # }
///
/// fn mutate_random_individuals(_models: &Vec<Model>) -> Model
/// {
/// //...
/// # Model {
/// # fit_score: 0.0
/// # }
/// }
///
/// impl GeneticNode for Node {
/// # fn initialize() -> Result<Box<Node>, Error> {
/// # Ok(Box::new(Node {
/// # models: vec![
/// # Model { fit_score: 0.0 },
/// # Model { fit_score: 1.0 },
/// # Model { fit_score: 2.0 },
/// # Model { fit_score: 3.0 },
/// # Model { fit_score: 4.0 },
/// # ],
/// # population_size: 5,
/// # }))
/// # }
/// #
/// # fn simulate(&mut self, iterations: u64) -> Result<(), Error> {
/// # for m in self.models.iter_mut() {
/// # m.fit(iterations)?;
/// # }
/// # Ok(())
/// # }
/// #
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
/// # self.models.sort_by(|a, b| a.fit_score.partial_cmp(&b.fit_score).unwrap().reverse());
/// # self.models.truncate(3);
/// # Ok(())
/// # }
/// //...
///
/// fn mutate(&mut self) -> Result<(), Error> {
/// loop {
/// if self.models.len() < self.population_size.try_into().unwrap()
/// {
/// self.models.push(mutate_random_individuals(&self.models))
/// }
/// else{
/// return Ok(());
/// }
/// }
/// }
/// #
/// # fn merge(left: &Node, right: &Node) -> Result<Box<Node>, Error> {
/// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}], population_size: 1}))
/// # }
/// }
///
/// # fn main() -> Result<(), Error> {
/// let mut node = Node::initialize()?;
/// assert_eq!(node.models.len(), 5);
///
/// node.simulate(5)?;
/// node.calculate_scores_and_trim()?;
/// assert_eq!(node.models.len(), 3);
///
/// node.mutate()?;
/// assert_eq!(node.models.len(), 5);
///
/// # assert_eq!(node.get_fit_score(), 9.0);
/// # Ok(())
/// # }
/// ```
fn mutate(&mut self) -> Result<(), Error>; fn mutate(&mut self) -> Result<(), Error>;
fn merge(left: &Self, right: &Self) -> Result<Box<Self>, Error>; fn merge(left: &Self, right: &Self) -> Result<Box<Self>, Error>;
@ -375,46 +71,7 @@ where
/// [`process_node`](#method.process_node). /// [`process_node`](#method.process_node).
/// ///
/// # Examples /// # Examples
/// ``` /// TODO
/// # use gemla::bracket::genetic_node::GeneticNode;
/// # use gemla::bracket::genetic_node::GeneticNodeWrapper;
/// # use gemla::error::Error;
/// # #[derive(Debug)]
/// struct Node {
/// # pub fit_score: f64,
/// //...
/// }
///
/// impl GeneticNode for Node {
/// //...
/// # fn initialize() -> Result<Box<Node>, Error> {
/// # Ok(Box::new(Node {fit_score: 0.0}))
/// # }
/// #
/// #
/// # fn simulate(&mut self, iterations: u64) -> Result<(), Error> {
/// # Ok(())
/// # }
/// #
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
/// # Ok(())
/// # }
/// #
/// # fn mutate(&mut self) -> Result<(), Error> {
/// # Ok(())
/// # }
/// #
/// # fn merge(left: &Node, right: &Node) -> Result<Box<Node>, Error> {
/// # Ok(Box::new(Node {fit_score: 0.0}))
/// # }
/// }
///
/// # fn main() -> Result<(), Error> {
/// let mut wrapped_node = GeneticNodeWrapper::<Node>::new()?;
/// assert_eq!(wrapped_node.data.unwrap().fit_score, 0.0);
/// # Ok(())
/// # }
/// ```
pub fn new() -> Result<Self, Error> { pub fn new() -> Result<Self, Error> {
let mut node = GeneticNodeWrapper { let mut node = GeneticNodeWrapper {
data: None, data: None,
@ -448,14 +105,11 @@ where
/// - `GeneticState::Initialize`: will attempt to call [`initialize`] on the node. When done successfully will change /// - `GeneticState::Initialize`: will attempt to call [`initialize`] on the node. When done successfully will change
/// the state to `GeneticState::Simulate` /// the state to `GeneticState::Simulate`
/// - `GeneticState::Simulate`: Will call [`simulate`] with a number of iterations (not for `iterations`). Will change the state to `GeneticState::Score` /// - `GeneticState::Simulate`: Will call [`simulate`] with a number of iterations (not for `iterations`). Will change the state to `GeneticState::Score`
/// - `GeneticState::Score`: Will call [`calculate_scores_and_trim`] and when the number of `iterations` have been reached will change
/// state to `GeneticState::Finish`, otherwise it will change the state to `GeneticState::Mutate.
/// - `GeneticState::Mutate`: Will call [`mutate`] and will change the state to `GeneticState::Simulate.` /// - `GeneticState::Mutate`: Will call [`mutate`] and will change the state to `GeneticState::Simulate.`
/// - `GeneticState::Finish`: Will finish processing the node and return. /// - `GeneticState::Finish`: Will finish processing the node and return.
/// ///
/// [`initialize`]: crate::bracket::genetic_node::GeneticNode#tymethod.initialize /// [`initialize`]: crate::bracket::genetic_node::GeneticNode#tymethod.initialize
/// [`simulate`]: crate::bracket::genetic_node::GeneticNode#tymethod.simulate /// [`simulate`]: crate::bracket::genetic_node::GeneticNode#tymethod.simulate
/// [`calculate_scores_and_trim`]: crate::bracket::genetic_node::GeneticNode#tymethod.calculate_scores_and_trim
/// [`mutate`]: crate::bracket::genetic_node::GeneticNode#tymethod.mutate /// [`mutate`]: crate::bracket::genetic_node::GeneticNode#tymethod.mutate
pub fn process_node(&mut self, iterations: u64) -> Result<(), Error> { pub fn process_node(&mut self, iterations: u64) -> Result<(), Error> {
// Looping through each state transition until the number of iterations have been reached. // Looping through each state transition until the number of iterations have been reached.
@ -474,20 +128,12 @@ where
.unwrap() .unwrap()
.simulate(5) .simulate(5)
.with_context(|| format!("Error simulating node: {:?}", self))?; .with_context(|| format!("Error simulating node: {:?}", self))?;
self.state = GeneticState::Score;
}
(GeneticState::Score, Some(_)) => {
self.data
.as_mut()
.unwrap()
.calculate_scores_and_trim()
.with_context(|| format!("Error scoring and trimming node: {:?}", self))?;
self.state = if self.iteration == iterations { self.state = if self.iteration == iterations {
GeneticState::Finish GeneticState::Finish
} else { } else {
GeneticState::Mutate GeneticState::Mutate
} };
} }
(GeneticState::Mutate, Some(_)) => { (GeneticState::Mutate, Some(_)) => {
self.data self.data

View file

@ -136,8 +136,6 @@ where
Ok(Gemla { Ok(Gemla {
data: if overwrite { data: if overwrite {
FileLinked::from_file(path)?
} else {
FileLinked::new( FileLinked::new(
Bracket { Bracket {
tree: Some(btree!(None)), tree: Some(btree!(None)),
@ -145,6 +143,8 @@ where
}, },
path, path,
)? )?
} else {
FileLinked::from_file(path)?
}, },
}) })
} }
@ -196,10 +196,6 @@ mod tests {
Ok(()) Ok(())
} }
fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
Ok(())
}
fn mutate(&mut self) -> Result<(), Error> { fn mutate(&mut self) -> Result<(), Error> {
Ok(()) Ok(())
} }