diff --git a/gemla/src/bin/bin.rs b/gemla/src/bin/bin.rs index cfeb3da..4e6ed01 100644 --- a/gemla/src/bin/bin.rs +++ b/gemla/src/bin/bin.rs @@ -23,8 +23,7 @@ fn main() -> anyhow::Result<()> { let file_path = matches.value_of(gemla::constants::args::FILE).unwrap(); let mut gemla = Gemla::::new(&PathBuf::from(file_path), true)?; - gemla.simulate(1)?; - gemla.simulate(1)?; + gemla.simulate(3)?; gemla.simulate(1)?; Ok(()) diff --git a/gemla/src/bin/test_state/mod.rs b/gemla/src/bin/test_state/mod.rs index c95b25a..12846ca 100644 --- a/gemla/src/bin/test_state/mod.rs +++ b/gemla/src/bin/test_state/mod.rs @@ -1,7 +1,7 @@ use gemla::bracket::genetic_node::GeneticNode; use gemla::error; use rand::prelude::*; -use rand::{random, thread_rng}; +use rand::thread_rng; use serde::{Deserialize, Serialize}; use std::convert::TryInto; @@ -10,15 +10,15 @@ const POPULATION_REDUCTION_SIZE: u64 = 3; #[derive(Serialize, Deserialize, Debug)] pub struct TestState { - pub population: Vec, + pub population: Vec, } impl Default for TestState { fn default() -> Self { - let mut population: Vec = vec![]; + let mut population: Vec = vec![]; for _ in 0..POPULATION_SIZE { - population.push(random::() as f64) + population.push(thread_rng().gen_range(0..10000)) } TestState { population } @@ -27,10 +27,10 @@ impl Default for TestState { impl GeneticNode for TestState { fn initialize() -> Result, error::Error> { - let mut population: Vec = vec![]; + let mut population: Vec = vec![]; for _ in 0..POPULATION_SIZE { - population.push(random::() as f64) + population.push(thread_rng().gen_range(0..10000)) } Ok(Box::new(TestState { population })) @@ -44,26 +44,23 @@ impl GeneticNode for TestState { .population .clone() .iter() - .map(|p| p + rng.gen_range(-10.0..10.0)) + .map(|p| p + rng.gen_range(-10..10)) .collect() } Ok(()) } - fn calculate_scores_and_trim(&mut self) -> Result<(), error::Error> { + fn mutate(&mut self) -> Result<(), error::Error> { + let mut rng = thread_rng(); + let mut v = self.population.clone(); - v.sort_by(|a, b| a.partial_cmp(b).unwrap()); + v.sort(); v.reverse(); self.population = v[0..(POPULATION_REDUCTION_SIZE as usize)].to_vec(); - Ok(()) - } - - fn mutate(&mut self) -> Result<(), error::Error> { - let mut rng = thread_rng(); loop { if self.population.len() >= POPULATION_SIZE.try_into().unwrap() { break; @@ -83,7 +80,7 @@ impl GeneticNode for TestState { let mut new_individual = self.population.clone()[new_individual_index]; let cross_breed = self.population.clone()[cross_breed_index]; - new_individual += cross_breed + rng.gen_range(-10.0..10.0); + new_individual += cross_breed + rng.gen_range(-10..10); self.population.push(new_individual); } @@ -123,7 +120,7 @@ mod tests { #[test] fn test_simulate() { let mut state = TestState { - population: vec![1.0, 1.0, 2.0, 3.0], + population: vec![1, 1, 2, 3], }; let original_population = state.population.clone(); @@ -135,33 +132,19 @@ mod tests { assert!(original_population .iter() .zip(state.population.iter()) - .all(|(&a, &b)| b >= a - 10.0 && b <= a + 10.0)); + .all(|(&a, &b)| b >= a - 10 && b <= a + 10)); state.simulate(2).unwrap(); assert!(original_population .iter() .zip(state.population.iter()) - .all(|(&a, &b)| b >= a - 30.0 && b <= a + 30.0)) - } - - #[test] - fn test_calculate_scores_and_trim() { - let mut state = TestState { - population: vec![4.0, 1.0, 1.0, 3.0, 2.0], - }; - - state.calculate_scores_and_trim().unwrap(); - - assert_eq!(state.population.len(), POPULATION_REDUCTION_SIZE as usize); - assert!(state.population.iter().any(|&x| x == 4.0)); - assert!(state.population.iter().any(|&x| x == 3.0)); - assert!(state.population.iter().any(|&x| x == 2.0)); + .all(|(&a, &b)| b >= a - 30 && b <= a + 30)) } #[test] fn test_mutate() { let mut state = TestState { - population: vec![4.0, 3.0, 3.0], + population: vec![4, 3, 3], }; state.mutate().unwrap(); @@ -172,18 +155,18 @@ mod tests { #[test] fn test_merge() { let state1 = TestState { - population: vec![1.0, 2.0, 4.0, 5.0], + population: vec![1, 2, 4, 5], }; let state2 = TestState { - population: vec![0.0, 1.0, 3.0, 7.0], + population: vec![0, 1, 3, 7], }; let merged_state = TestState::merge(&state1, &state2).unwrap(); assert_eq!(merged_state.population.len(), POPULATION_SIZE as usize); - assert!(merged_state.population.iter().any(|&x| x == 7.0)); - assert!(merged_state.population.iter().any(|&x| x == 5.0)); - assert!(merged_state.population.iter().any(|&x| x == 4.0)); + assert!(merged_state.population.iter().any(|&x| x == 7)); + assert!(merged_state.population.iter().any(|&x| x == 5)); + assert!(merged_state.population.iter().any(|&x| x == 4)); } } diff --git a/gemla/src/bracket/genetic_node.rs b/gemla/src/bracket/genetic_node.rs index e2ce60e..d669592 100644 --- a/gemla/src/bracket/genetic_node.rs +++ b/gemla/src/bracket/genetic_node.rs @@ -18,8 +18,6 @@ pub enum GeneticState { Initialize, /// The node is currently simulating a round against target data to determine the fitness of the population Simulate, - /// The node is currently selecting members of the population that scored well and reducing the total population size - Score, /// The node is currently mutating members of it's population and breeding new members Mutate, /// The node has finished processing for a given number of iterations @@ -33,322 +31,20 @@ pub trait GeneticNode { /// Initializes a new instance of a [`GeneticState`]. /// /// # Examples - /// - /// ``` - /// # use gemla::bracket::genetic_node::GeneticNode; - /// # use gemla::error::Error; - /// # - /// struct Node { - /// pub fit_score: f64, - /// } - /// - /// impl GeneticNode for Node { - /// fn initialize() -> Result, Error> { - /// Ok(Box::new(Node {fit_score: 0.0})) - /// } - /// - /// //... - /// # - /// # fn simulate(&mut self, iterations: u64) -> Result<(), Error> { - /// # Ok(()) - /// # } - /// # - /// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> { - /// # Ok(()) - /// # } - /// # - /// # fn mutate(&mut self) -> Result<(), Error> { - /// # Ok(()) - /// # } - /// # - /// # fn merge(left: &Node, right: &Node) -> Result, Error> { - /// # Ok(Box::new(Node {fit_score: 0.0})) - /// # } - /// } - /// - /// # fn main() -> Result<(), Error> { - /// let node = Node::initialize()?; - /// assert_eq!(node.fit_score, 0.0); - /// # Ok(()) - /// # } - /// ``` + /// TODO fn initialize() -> Result, Error>; /// Runs a simulation on the state object for the given number of `iterations` in order to guage it's fitness. /// This will be called for every node in a bracket before evaluating it's fitness against other nodes. /// /// # Examples - /// - /// ``` - /// # use gemla::bracket::genetic_node::GeneticNode; - /// # use gemla::error::Error; - /// # - /// struct Model { - /// pub fit_score: f64, - /// //... - /// } - /// - /// struct Node { - /// pub models: Vec, - /// //... - /// } - /// - /// impl Model { - /// fn fit(&mut self, epochs: u64) -> Result<(), Error> { - /// //... - /// # self.fit_score += epochs as f64; - /// # Ok(()) - /// } - /// } - /// - /// # impl Node { - /// # fn get_fit_score(&self) -> f64 { - /// # self.models - /// # .iter() - /// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap()) - /// # .unwrap() - /// # .fit_score - /// # } - /// # } - /// # - /// impl GeneticNode for Node { - /// # fn initialize() -> Result, Error> { - /// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]})) - /// # } - /// # - /// //... - /// - /// fn simulate(&mut self, iterations: u64) -> Result<(), Error> { - /// for m in self.models.iter_mut() - /// { - /// m.fit(iterations)?; - /// } - /// Ok(()) - /// } - /// - /// //... - /// # - /// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> { - /// # Ok(()) - /// # } - /// # - /// # fn mutate(&mut self) -> Result<(), Error> { - /// # Ok(()) - /// # } - /// # - /// # fn merge(left: &Node, right: &Node) -> Result, Error> { - /// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]})) - /// # } - /// } - /// - /// # fn main() -> Result<(), Error> { - /// let mut node = Node::initialize()?; - /// node.simulate(5)?; - /// assert_eq!(node.get_fit_score(), 5.0); - /// # Ok(()) - /// # } - /// ``` + /// TODO fn simulate(&mut self, iterations: u64) -> Result<(), Error>; - /// Used when scoring the nodes after simulating and should remove underperforming children. - /// - /// # Examples - /// ``` - /// # use gemla::bracket::genetic_node::GeneticNode; - /// # use gemla::error::Error; - /// # - /// struct Model { - /// pub fit_score: f64, - /// //... - /// } - /// - /// struct Node { - /// pub models: Vec, - /// population_size: i64, - /// //... - /// } - /// - /// # impl Model { - /// # fn fit(&mut self, epochs: u64) -> Result<(), Error> { - /// # //... - /// # self.fit_score += epochs as f64; - /// # Ok(()) - /// # } - /// # } - /// # - /// # - /// # impl Node { - /// # fn get_fit_score(&self) -> f64 { - /// # self.models - /// # .iter() - /// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap()) - /// # .unwrap() - /// # .fit_score - /// # } - /// # } - /// # - /// - /// impl GeneticNode for Node { - /// # fn initialize() -> Result, Error> { - /// # Ok(Box::new(Node { - /// # models: vec![ - /// # Model { fit_score: 0.0 }, - /// # Model { fit_score: 1.0 }, - /// # Model { fit_score: 2.0 }, - /// # Model { fit_score: 3.0 }, - /// # Model { fit_score: 4.0 }, - /// # ], - /// # population_size: 5, - /// # })) - /// # } - /// # - /// # //... - /// # - /// # fn simulate(&mut self, iterations: u64) -> Result<(), Error> { - /// # for m in self.models.iter_mut() { - /// # m.fit(iterations)?; - /// # } - /// # Ok(()) - /// # } - /// # - /// //... - /// - /// fn calculate_scores_and_trim(&mut self) -> Result<(), Error> { - /// self.models.sort_by(|a, b| a.fit_score.partial_cmp(&b.fit_score).unwrap().reverse()); - /// self.models.truncate(3); - /// Ok(()) - /// } - /// - /// //... - /// # - /// # fn mutate(&mut self) -> Result<(), Error> { - /// # Ok(()) - /// # } - /// # - /// # fn merge(left: &Node, right: &Node) -> Result, Error> { - /// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}], population_size: 1})) - /// # } - /// } - /// - /// # fn main() -> Result<(), Error> { - /// let mut node = Node::initialize()?; - /// assert_eq!(node.models.len(), 5); - /// - /// node.simulate(5)?; - /// node.calculate_scores_and_trim()?; - /// assert_eq!(node.models.len(), 3); - /// - /// # assert_eq!(node.get_fit_score(), 9.0); - /// # Ok(()) - /// # } - /// ``` - fn calculate_scores_and_trim(&mut self) -> Result<(), Error>; - /// Mutates members in a population and/or crossbreeds them to produce new offspring. /// /// # Examples - /// ``` - /// # use gemla::bracket::genetic_node::GeneticNode; - /// # use gemla::error::Error; - /// # use std::convert::TryInto; - /// # - /// struct Model { - /// pub fit_score: f64, - /// //... - /// } - /// - /// struct Node { - /// pub models: Vec, - /// population_size: i64, - /// //... - /// } - /// # - /// # impl Node { - /// # fn get_fit_score(&self) -> f64 { - /// # self.models - /// # .iter() - /// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap()) - /// # .unwrap() - /// # .fit_score - /// # } - /// # } - /// - /// # impl Model { - /// # fn fit(&mut self, epochs: u64) -> Result<(), Error> { - /// # //... - /// # self.fit_score += epochs as f64; - /// # Ok(()) - /// # } - /// # } - /// - /// fn mutate_random_individuals(_models: &Vec) -> Model - /// { - /// //... - /// # Model { - /// # fit_score: 0.0 - /// # } - /// } - /// - /// impl GeneticNode for Node { - /// # fn initialize() -> Result, Error> { - /// # Ok(Box::new(Node { - /// # models: vec![ - /// # Model { fit_score: 0.0 }, - /// # Model { fit_score: 1.0 }, - /// # Model { fit_score: 2.0 }, - /// # Model { fit_score: 3.0 }, - /// # Model { fit_score: 4.0 }, - /// # ], - /// # population_size: 5, - /// # })) - /// # } - /// # - /// # fn simulate(&mut self, iterations: u64) -> Result<(), Error> { - /// # for m in self.models.iter_mut() { - /// # m.fit(iterations)?; - /// # } - /// # Ok(()) - /// # } - /// # - /// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> { - /// # self.models.sort_by(|a, b| a.fit_score.partial_cmp(&b.fit_score).unwrap().reverse()); - /// # self.models.truncate(3); - /// # Ok(()) - /// # } - /// //... - /// - /// fn mutate(&mut self) -> Result<(), Error> { - /// loop { - /// if self.models.len() < self.population_size.try_into().unwrap() - /// { - /// self.models.push(mutate_random_individuals(&self.models)) - /// } - /// else{ - /// return Ok(()); - /// } - /// } - /// } - /// # - /// # fn merge(left: &Node, right: &Node) -> Result, Error> { - /// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}], population_size: 1})) - /// # } - /// } - /// - /// # fn main() -> Result<(), Error> { - /// let mut node = Node::initialize()?; - /// assert_eq!(node.models.len(), 5); - /// - /// node.simulate(5)?; - /// node.calculate_scores_and_trim()?; - /// assert_eq!(node.models.len(), 3); - /// - /// node.mutate()?; - /// assert_eq!(node.models.len(), 5); - /// - /// # assert_eq!(node.get_fit_score(), 9.0); - /// # Ok(()) - /// # } - /// ``` + /// TODO fn mutate(&mut self) -> Result<(), Error>; fn merge(left: &Self, right: &Self) -> Result, Error>; @@ -375,46 +71,7 @@ where /// [`process_node`](#method.process_node). /// /// # Examples - /// ``` - /// # use gemla::bracket::genetic_node::GeneticNode; - /// # use gemla::bracket::genetic_node::GeneticNodeWrapper; - /// # use gemla::error::Error; - /// # #[derive(Debug)] - /// struct Node { - /// # pub fit_score: f64, - /// //... - /// } - /// - /// impl GeneticNode for Node { - /// //... - /// # fn initialize() -> Result, Error> { - /// # Ok(Box::new(Node {fit_score: 0.0})) - /// # } - /// # - /// # - /// # fn simulate(&mut self, iterations: u64) -> Result<(), Error> { - /// # Ok(()) - /// # } - /// # - /// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> { - /// # Ok(()) - /// # } - /// # - /// # fn mutate(&mut self) -> Result<(), Error> { - /// # Ok(()) - /// # } - /// # - /// # fn merge(left: &Node, right: &Node) -> Result, Error> { - /// # Ok(Box::new(Node {fit_score: 0.0})) - /// # } - /// } - /// - /// # fn main() -> Result<(), Error> { - /// let mut wrapped_node = GeneticNodeWrapper::::new()?; - /// assert_eq!(wrapped_node.data.unwrap().fit_score, 0.0); - /// # Ok(()) - /// # } - /// ``` + /// TODO pub fn new() -> Result { let mut node = GeneticNodeWrapper { data: None, @@ -448,14 +105,11 @@ where /// - `GeneticState::Initialize`: will attempt to call [`initialize`] on the node. When done successfully will change /// the state to `GeneticState::Simulate` /// - `GeneticState::Simulate`: Will call [`simulate`] with a number of iterations (not for `iterations`). Will change the state to `GeneticState::Score` - /// - `GeneticState::Score`: Will call [`calculate_scores_and_trim`] and when the number of `iterations` have been reached will change - /// state to `GeneticState::Finish`, otherwise it will change the state to `GeneticState::Mutate. /// - `GeneticState::Mutate`: Will call [`mutate`] and will change the state to `GeneticState::Simulate.` /// - `GeneticState::Finish`: Will finish processing the node and return. /// /// [`initialize`]: crate::bracket::genetic_node::GeneticNode#tymethod.initialize /// [`simulate`]: crate::bracket::genetic_node::GeneticNode#tymethod.simulate - /// [`calculate_scores_and_trim`]: crate::bracket::genetic_node::GeneticNode#tymethod.calculate_scores_and_trim /// [`mutate`]: crate::bracket::genetic_node::GeneticNode#tymethod.mutate pub fn process_node(&mut self, iterations: u64) -> Result<(), Error> { // Looping through each state transition until the number of iterations have been reached. @@ -474,20 +128,12 @@ where .unwrap() .simulate(5) .with_context(|| format!("Error simulating node: {:?}", self))?; - self.state = GeneticState::Score; - } - (GeneticState::Score, Some(_)) => { - self.data - .as_mut() - .unwrap() - .calculate_scores_and_trim() - .with_context(|| format!("Error scoring and trimming node: {:?}", self))?; self.state = if self.iteration == iterations { GeneticState::Finish } else { GeneticState::Mutate - } + }; } (GeneticState::Mutate, Some(_)) => { self.data diff --git a/gemla/src/bracket/mod.rs b/gemla/src/bracket/mod.rs index 038b7ca..0764f76 100644 --- a/gemla/src/bracket/mod.rs +++ b/gemla/src/bracket/mod.rs @@ -136,8 +136,6 @@ where Ok(Gemla { data: if overwrite { - FileLinked::from_file(path)? - } else { FileLinked::new( Bracket { tree: Some(btree!(None)), @@ -145,6 +143,8 @@ where }, path, )? + } else { + FileLinked::from_file(path)? }, }) } @@ -196,10 +196,6 @@ mod tests { Ok(()) } - fn calculate_scores_and_trim(&mut self) -> Result<(), Error> { - Ok(()) - } - fn mutate(&mut self) -> Result<(), Error> { Ok(()) }