diff --git a/gemla/src/bracket/genetic_node.rs b/gemla/src/bracket/genetic_node.rs index a452f7b..7eeb832 100644 --- a/gemla/src/bracket/genetic_node.rs +++ b/gemla/src/bracket/genetic_node.rs @@ -1,5 +1,5 @@ //! A trait used to interact with the internal state of nodes within the [`Bracket`] -//! +//! //! [`Bracket`]: crate::bracket::Bracket use super::genetic_state::GeneticState; @@ -12,21 +12,21 @@ use std::fmt; /// [`Bracket`]: crate::bracket::Bracket pub trait GeneticNode { /// Initializes a new instance of a [`GeneticState`]. - /// + /// /// # Examples - /// + /// /// ``` /// # use gemla::bracket::genetic_node::GeneticNode; /// # /// struct Node { /// pub fit_score: f64, /// } - /// + /// /// impl GeneticNode for Node { /// fn initialize() -> Result, String> { /// Ok(Box::new(Node {fit_score: 0.0})) /// } - /// + /// /// //... /// # /// # fn simulate(&mut self, iterations: u64) -> Result<(), String> { @@ -45,7 +45,7 @@ pub trait GeneticNode { /// # Ok(()) /// # } /// } - /// + /// /// # fn main() -> Result<(), String> { /// let node = Node::initialize()?; /// assert_eq!(node.get_fit_score(), 0.0); @@ -66,62 +66,63 @@ pub trait GeneticNode { /// pub fit_score: f64, /// //... /// } - /// + /// /// struct Node { - /// pub model: Vec, + /// pub models: Vec, /// //... /// } /// /// impl Model { /// fn fit(&mut self, epochs: u64) -> Result<(), String> { /// //... - /// # self.fit_score += epochs as f64; - /// # Ok(()) + /// # self.fit_score += epochs as f64; + /// # Ok(()) /// } /// } - /// + /// /// impl GeneticNode for Node { /// # fn initialize() -> Result, String> { - /// # Ok(Box::new(Node {fit_score: 0.0, model: Model {fit_score: 0.0}})) + /// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]})) /// # } /// # /// //... - /// - /// fn simulate(&mut self, iterations: u64) -> Result<(), String> { - /// self.model.fit(iterations)?; - /// self.fit_score = self.model.fit_score; - /// Ok(()) - /// } + /// + /// fn simulate(&mut self, iterations: u64) -> Result<(), String> { + /// for m in self.models.iter_mut() + /// { + /// m.fit(iterations)?; + /// } + /// Ok(()) + /// } /// /// //... - /// # - /// # fn get_fit_score(&self) -> f64 { - /// # self.fit_score - /// # } - /// # - /// # fn calculate_scores_and_trim(&mut self) -> Result<(), String> { - /// # Ok(()) - /// # } - /// # - /// # fn mutate(&mut self) -> Result<(), String> { - /// # Ok(()) - /// # } + /// + /// # fn get_fit_score(&self) -> f64 { + /// # self.models.iter().max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap()).unwrap().fit_score + /// # } + /// # + /// # fn calculate_scores_and_trim(&mut self) -> Result<(), String> { + /// # Ok(()) + /// # } + /// # + /// # fn mutate(&mut self) -> Result<(), String> { + /// # Ok(()) + /// # } /// } - /// + /// /// # fn main() -> Result<(), String> { /// let mut node = Node::initialize()?; - /// (*node).simulate(5)?; - /// # assert_eq!(node.get_fit_score(), 5.0); - /// # Ok(()) + /// node.simulate(5)?; + /// assert_eq!(node.get_fit_score(), 5.0); + /// # Ok(()) /// # } - /// ``` + /// ``` fn simulate(&mut self, iterations: u64) -> Result<(), String>; /// Returns a fit score associated with the nodes performance. /// This will be used by a bracket in order to determine the most successful child. - /// + /// /// # Examples - /// /// ``` /// # use gemla::bracket::genetic_node::GeneticNode; /// # @@ -130,47 +131,242 @@ pub trait GeneticNode { /// //... /// } /// - /// impl GeneticNode for Model { - /// # fn initialize() -> Result, String> { - /// # Ok(Box::new(Model {fit_score: 0.0, model: Model {fit_score: 0.0}})) - /// # } - /// + /// struct Node { + /// pub models: Vec, /// //... - /// - /// # fn simulate(&mut self, iterations: u64) -> Result<(), String> { - /// # self.model.fit(iterations)?; - /// # self.fit_score = self.model.fit_score; - /// # Ok(()) - /// # } - /// # - /// # - /// fn get_fit_score(&self) -> f64 { - /// self.fit_score - /// } - /// - /// //... - /// # - /// # fn calculate_scores_and_trim(&mut self) -> Result<(), String> { - /// # Ok(()) - /// # } - /// # - /// # fn mutate(&mut self) -> Result<(), String> { - /// # Ok(()) - /// # } /// } - /// - /// # fn main() -> Result<(), String> { - /// let mut model = Model::initialize()?; - /// assert_eq!(node.get_fit_score(), 0.0); - /// # Ok(()) + /// + /// # impl Model { + /// # fn fit(&mut self, epochs: u64) -> Result<(), String> { + /// # //... + /// # self.fit_score += epochs as f64; + /// # Ok(()) + /// # } /// # } - /// ``` + /// + /// impl GeneticNode for Node { + /// # fn initialize() -> Result, String> { + /// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]})) + /// # } + /// # + /// # //... + /// # + /// # fn simulate(&mut self, iterations: u64) -> Result<(), String> { + /// # for m in self.models.iter_mut() + /// # { + /// # m.fit(iterations)?; + /// # } + /// # Ok(()) + /// # } + /// # + /// //... + /// + /// fn get_fit_score(&self) -> f64 { + /// self.models.iter().max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap()).unwrap().fit_score + /// } + /// + /// //... + /// # fn calculate_scores_and_trim(&mut self) -> Result<(), String> { + /// # Ok(()) + /// # } + /// # + /// # fn mutate(&mut self) -> Result<(), String> { + /// # Ok(()) + /// # } + /// } + /// + /// # fn main() -> Result<(), String> { + /// let mut node = Node::initialize()?; + /// node.simulate(5)?; + /// assert_eq!(node.get_fit_score(), 5.0); + /// # Ok(()) + /// # } + /// ``` fn get_fit_score(&self) -> f64; /// Used when scoring the nodes after simulating and should remove underperforming children. + /// + /// # Examples + /// ``` + /// # use gemla::bracket::genetic_node::GeneticNode; + /// # + /// struct Model { + /// pub fit_score: f64, + /// //... + /// } + /// + /// struct Node { + /// pub models: Vec, + /// population_size: i64, + /// //... + /// } + /// + /// # impl Model { + /// # fn fit(&mut self, epochs: u64) -> Result<(), String> { + /// # //... + /// # self.fit_score += epochs as f64; + /// # Ok(()) + /// # } + /// # } + /// + /// impl GeneticNode for Node { + /// # fn initialize() -> Result, String> { + /// # Ok(Box::new(Node { + /// # models: vec![ + /// # Model { fit_score: 0.0 }, + /// # Model { fit_score: 1.0 }, + /// # Model { fit_score: 2.0 }, + /// # Model { fit_score: 3.0 }, + /// # Model { fit_score: 4.0 }, + /// # ], + /// # population_size: 5, + /// # })) + /// # } + /// # + /// # //... + /// # + /// # fn simulate(&mut self, iterations: u64) -> Result<(), String> { + /// # for m in self.models.iter_mut() { + /// # m.fit(iterations)?; + /// # } + /// # Ok(()) + /// # } + /// # + /// //... + /// + /// # fn get_fit_score(&self) -> f64 { + /// # self.models + /// # .iter() + /// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap()) + /// # .unwrap() + /// # .fit_score + /// # } + /// # + /// fn calculate_scores_and_trim(&mut self) -> Result<(), String> { + /// self.models.sort_by(|a, b| a.fit_score.partial_cmp(&b.fit_score).unwrap().reverse()); + /// self.models.truncate(3); + /// Ok(()) + /// } + /// + /// //... + /// # + /// # fn mutate(&mut self) -> Result<(), String> { + /// # Ok(()) + /// # } + /// } + /// + /// # fn main() -> Result<(), String> { + /// let mut node = Node::initialize()?; + /// assert_eq!(node.models.len(), 5); + /// + /// node.simulate(5)?; + /// node.calculate_scores_and_trim()?; + /// assert_eq!(node.models.len(), 3); + /// + /// # assert_eq!(node.get_fit_score(), 9.0); + /// # Ok(()) + /// # } + /// ``` fn calculate_scores_and_trim(&mut self) -> Result<(), String>; /// Mutates members in a population and/or crossbreeds them to produce new offspring. + /// + /// # Examples + /// ``` + /// # use gemla::bracket::genetic_node::GeneticNode; + /// # use std::convert::TryInto; + /// # + /// struct Model { + /// pub fit_score: f64, + /// //... + /// } + /// + /// struct Node { + /// pub models: Vec, + /// population_size: i64, + /// //... + /// } + /// + /// # impl Model { + /// # fn fit(&mut self, epochs: u64) -> Result<(), String> { + /// # //... + /// # self.fit_score += epochs as f64; + /// # Ok(()) + /// # } + /// # } + /// + /// fn mutate_random_individuals(_models: &Vec) -> Model + /// { + /// //... + /// # Model { + /// # fit_score: 0.0 + /// # } + /// } + /// + /// impl GeneticNode for Node { + /// # fn initialize() -> Result, String> { + /// # Ok(Box::new(Node { + /// # models: vec![ + /// # Model { fit_score: 0.0 }, + /// # Model { fit_score: 1.0 }, + /// # Model { fit_score: 2.0 }, + /// # Model { fit_score: 3.0 }, + /// # Model { fit_score: 4.0 }, + /// # ], + /// # population_size: 5, + /// # })) + /// # } + /// # + /// # fn simulate(&mut self, iterations: u64) -> Result<(), String> { + /// # for m in self.models.iter_mut() { + /// # m.fit(iterations)?; + /// # } + /// # Ok(()) + /// # } + /// # + /// # fn get_fit_score(&self) -> f64 { + /// # self.models + /// # .iter() + /// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap()) + /// # .unwrap() + /// # .fit_score + /// # } + /// # + /// # fn calculate_scores_and_trim(&mut self) -> Result<(), String> { + /// # self.models.sort_by(|a, b| a.fit_score.partial_cmp(&b.fit_score).unwrap().reverse()); + /// # self.models.truncate(3); + /// # Ok(()) + /// # } + /// //... + /// + /// fn mutate(&mut self) -> Result<(), String> { + /// loop { + /// if self.models.len() < self.population_size.try_into().unwrap() + /// { + /// self.models.push(mutate_random_individuals(&self.models)) + /// } + /// else{ + /// return Ok(()); + /// } + /// } + /// } + /// } + /// + /// # fn main() -> Result<(), String> { + /// let mut node = Node::initialize()?; + /// assert_eq!(node.models.len(), 5); + /// + /// node.simulate(5)?; + /// node.calculate_scores_and_trim()?; + /// assert_eq!(node.models.len(), 3); + /// + /// node.mutate()?; + /// assert_eq!(node.models.len(), 5); + /// + /// # assert_eq!(node.get_fit_score(), 9.0); + /// # Ok(()) + /// # } + /// ``` fn mutate(&mut self) -> Result<(), String>; } @@ -191,8 +387,8 @@ where T: GeneticNode + fmt::Debug, { /// Initializes a wrapper around a GeneticNode. If the initialization is successful the internal state will be changed to - /// `GeneticState::Simulate` otherwise it will remain as `GeneticState::Initialize` and will attempt to be created in - /// [`process_node`](#method.process_node). + /// `GeneticState::Simulate` otherwise it will remain as `GeneticState::Initialize` and will attempt to be created in + /// [`process_node`](#method.process_node). /// /// # Examples /// ``` @@ -252,7 +448,7 @@ where /// Will loop through the node training and scoring process for the given number of `iterations`. /// /// ## Transitions - /// - `GeneticState::Initialize`: will attempt to call [`initialize`] on the node. When done successfully will change + /// - `GeneticState::Initialize`: will attempt to call [`initialize`] on the node. When done successfully will change /// the state to `GeneticState::Simulate` /// - `GeneticState::Simulate`: Will call [`simulate`] with a number of iterations (not for `iterations`). Will change the state to `GeneticState::Score` /// - `GeneticState::Score`: Will call [`calculate_scores_and_trim`] and when the number of `iterations` have been reached will change diff --git a/gemla/src/bracket/genetic_state.rs b/gemla/src/bracket/genetic_state.rs index 7ca7bca..5630a35 100644 --- a/gemla/src/bracket/genetic_state.rs +++ b/gemla/src/bracket/genetic_state.rs @@ -1,11 +1,11 @@ //! An enum used to control the state of a [`GeneticNode`] -//! +//! //! [`GeneticNode`]: crate::bracket::genetic_node use serde::{Deserialize, Serialize}; /// An enum used to control the state of a [`GeneticNode`] -/// +/// /// [`GeneticNode`]: crate::bracket::genetic_node #[derive(Clone, Debug, Serialize, Deserialize, Copy)] #[serde(tag = "enumType", content = "enumContent")] diff --git a/gemla/src/bracket/mod.rs b/gemla/src/bracket/mod.rs index 6e87158..51c7a18 100644 --- a/gemla/src/bracket/mod.rs +++ b/gemla/src/bracket/mod.rs @@ -124,10 +124,9 @@ where } #[cfg(test)] -mod tests -{ +mod tests { use super::*; - + use serde::{Deserialize, Serialize}; use std::fmt; use std::str::FromStr; @@ -237,5 +236,4 @@ mod tests std::fs::remove_file("./temp2").expect("Unable to remove file"); } - -} \ No newline at end of file +} diff --git a/gemla/src/file_linked/mod.rs b/gemla/src/file_linked/mod.rs index a6cde54..e43182b 100644 --- a/gemla/src/file_linked/mod.rs +++ b/gemla/src/file_linked/mod.rs @@ -141,5 +141,4 @@ mod tests { Ok(()) } - } diff --git a/gemla/src/lib.rs b/gemla/src/lib.rs index 70a08ef..c11009a 100644 --- a/gemla/src/lib.rs +++ b/gemla/src/lib.rs @@ -4,4 +4,4 @@ extern crate regex; pub mod tree; pub mod bracket; pub mod constants; -pub mod file_linked; \ No newline at end of file +pub mod file_linked; diff --git a/gemla/src/tree/mod.rs b/gemla/src/tree/mod.rs index 13b9f6f..dc0b3db 100644 --- a/gemla/src/tree/mod.rs +++ b/gemla/src/tree/mod.rs @@ -23,7 +23,6 @@ //! //# } //! ``` - use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use std::fmt; @@ -136,10 +135,8 @@ where } } - #[cfg(test)] -mod tests -{ +mod tests { use super::*; #[test] @@ -176,4 +173,4 @@ mod tests ); assert_eq!(Tree::::fmt_node(&None), "_"); } -} \ No newline at end of file +}