August 20th 2021, 2:23 am

This commit is contained in:
vandomej 2021-08-20 02:23:18 -07:00
parent 61786ab303
commit f6de0191cd
6 changed files with 276 additions and 86 deletions

View file

@ -1,5 +1,5 @@
//! A trait used to interact with the internal state of nodes within the [`Bracket`] //! A trait used to interact with the internal state of nodes within the [`Bracket`]
//! //!
//! [`Bracket`]: crate::bracket::Bracket //! [`Bracket`]: crate::bracket::Bracket
use super::genetic_state::GeneticState; use super::genetic_state::GeneticState;
@ -12,21 +12,21 @@ use std::fmt;
/// [`Bracket`]: crate::bracket::Bracket /// [`Bracket`]: crate::bracket::Bracket
pub trait GeneticNode { pub trait GeneticNode {
/// Initializes a new instance of a [`GeneticState`]. /// Initializes a new instance of a [`GeneticState`].
/// ///
/// # Examples /// # Examples
/// ///
/// ``` /// ```
/// # use gemla::bracket::genetic_node::GeneticNode; /// # use gemla::bracket::genetic_node::GeneticNode;
/// # /// #
/// struct Node { /// struct Node {
/// pub fit_score: f64, /// pub fit_score: f64,
/// } /// }
/// ///
/// impl GeneticNode for Node { /// impl GeneticNode for Node {
/// fn initialize() -> Result<Box<Self>, String> { /// fn initialize() -> Result<Box<Self>, String> {
/// Ok(Box::new(Node {fit_score: 0.0})) /// Ok(Box::new(Node {fit_score: 0.0}))
/// } /// }
/// ///
/// //... /// //...
/// # /// #
/// # fn simulate(&mut self, iterations: u64) -> Result<(), String> { /// # fn simulate(&mut self, iterations: u64) -> Result<(), String> {
@ -45,7 +45,7 @@ pub trait GeneticNode {
/// # Ok(()) /// # Ok(())
/// # } /// # }
/// } /// }
/// ///
/// # fn main() -> Result<(), String> { /// # fn main() -> Result<(), String> {
/// let node = Node::initialize()?; /// let node = Node::initialize()?;
/// assert_eq!(node.get_fit_score(), 0.0); /// assert_eq!(node.get_fit_score(), 0.0);
@ -66,62 +66,63 @@ pub trait GeneticNode {
/// pub fit_score: f64, /// pub fit_score: f64,
/// //... /// //...
/// } /// }
/// ///
/// struct Node { /// struct Node {
/// pub model: Vec<Model>, /// pub models: Vec<Model>,
/// //... /// //...
/// } /// }
/// ///
/// impl Model { /// impl Model {
/// fn fit(&mut self, epochs: u64) -> Result<(), String> { /// fn fit(&mut self, epochs: u64) -> Result<(), String> {
/// //... /// //...
/// # self.fit_score += epochs as f64; /// # self.fit_score += epochs as f64;
/// # Ok(()) /// # Ok(())
/// } /// }
/// } /// }
/// ///
/// impl GeneticNode for Node { /// impl GeneticNode for Node {
/// # fn initialize() -> Result<Box<Self>, String> { /// # fn initialize() -> Result<Box<Self>, String> {
/// # Ok(Box::new(Node {fit_score: 0.0, model: Model {fit_score: 0.0}})) /// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]}))
/// # } /// # }
/// # /// #
/// //... /// //...
/// ///
/// fn simulate(&mut self, iterations: u64) -> Result<(), String> { /// fn simulate(&mut self, iterations: u64) -> Result<(), String> {
/// self.model.fit(iterations)?; /// for m in self.models.iter_mut()
/// self.fit_score = self.model.fit_score; /// {
/// Ok(()) /// m.fit(iterations)?;
/// } /// }
/// Ok(())
/// }
/// ///
/// //... /// //...
/// # ///
/// # fn get_fit_score(&self) -> f64 { /// # fn get_fit_score(&self) -> f64 {
/// # self.fit_score /// # self.models.iter().max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap()).unwrap().fit_score
/// # } /// # }
/// # /// #
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), String> { /// # fn calculate_scores_and_trim(&mut self) -> Result<(), String> {
/// # Ok(()) /// # Ok(())
/// # } /// # }
/// # /// #
/// # fn mutate(&mut self) -> Result<(), String> { /// # fn mutate(&mut self) -> Result<(), String> {
/// # Ok(()) /// # Ok(())
/// # } /// # }
/// } /// }
/// ///
/// # fn main() -> Result<(), String> { /// # fn main() -> Result<(), String> {
/// let mut node = Node::initialize()?; /// let mut node = Node::initialize()?;
/// (*node).simulate(5)?; /// node.simulate(5)?;
/// # assert_eq!(node.get_fit_score(), 5.0); /// assert_eq!(node.get_fit_score(), 5.0);
/// # Ok(()) /// # Ok(())
/// # } /// # }
/// ``` /// ```
fn simulate(&mut self, iterations: u64) -> Result<(), String>; fn simulate(&mut self, iterations: u64) -> Result<(), String>;
/// Returns a fit score associated with the nodes performance. /// Returns a fit score associated with the nodes performance.
/// This will be used by a bracket in order to determine the most successful child. /// This will be used by a bracket in order to determine the most successful child.
/// ///
/// # Examples /// # Examples
///
/// ``` /// ```
/// # use gemla::bracket::genetic_node::GeneticNode; /// # use gemla::bracket::genetic_node::GeneticNode;
/// # /// #
@ -130,47 +131,242 @@ pub trait GeneticNode {
/// //... /// //...
/// } /// }
/// ///
/// impl GeneticNode for Model { /// struct Node {
/// # fn initialize() -> Result<Box<Self>, String> { /// pub models: Vec<Model>,
/// # Ok(Box::new(Model {fit_score: 0.0, model: Model {fit_score: 0.0}}))
/// # }
///
/// //... /// //...
///
/// # fn simulate(&mut self, iterations: u64) -> Result<(), String> {
/// # self.model.fit(iterations)?;
/// # self.fit_score = self.model.fit_score;
/// # Ok(())
/// # }
/// #
/// #
/// fn get_fit_score(&self) -> f64 {
/// self.fit_score
/// }
///
/// //...
/// #
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), String> {
/// # Ok(())
/// # }
/// #
/// # fn mutate(&mut self) -> Result<(), String> {
/// # Ok(())
/// # }
/// } /// }
/// ///
/// # fn main() -> Result<(), String> { /// # impl Model {
/// let mut model = Model::initialize()?; /// # fn fit(&mut self, epochs: u64) -> Result<(), String> {
/// assert_eq!(node.get_fit_score(), 0.0); /// # //...
/// # Ok(()) /// # self.fit_score += epochs as f64;
/// # Ok(())
/// # }
/// # } /// # }
/// ``` ///
/// impl GeneticNode for Node {
/// # fn initialize() -> Result<Box<Self>, String> {
/// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]}))
/// # }
/// #
/// # //...
/// #
/// # fn simulate(&mut self, iterations: u64) -> Result<(), String> {
/// # for m in self.models.iter_mut()
/// # {
/// # m.fit(iterations)?;
/// # }
/// # Ok(())
/// # }
/// #
/// //...
///
/// fn get_fit_score(&self) -> f64 {
/// self.models.iter().max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap()).unwrap().fit_score
/// }
///
/// //...
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), String> {
/// # Ok(())
/// # }
/// #
/// # fn mutate(&mut self) -> Result<(), String> {
/// # Ok(())
/// # }
/// }
///
/// # fn main() -> Result<(), String> {
/// let mut node = Node::initialize()?;
/// node.simulate(5)?;
/// assert_eq!(node.get_fit_score(), 5.0);
/// # Ok(())
/// # }
/// ```
fn get_fit_score(&self) -> f64; fn get_fit_score(&self) -> f64;
/// Used when scoring the nodes after simulating and should remove underperforming children. /// Used when scoring the nodes after simulating and should remove underperforming children.
///
/// # Examples
/// ```
/// # use gemla::bracket::genetic_node::GeneticNode;
/// #
/// struct Model {
/// pub fit_score: f64,
/// //...
/// }
///
/// struct Node {
/// pub models: Vec<Model>,
/// population_size: i64,
/// //...
/// }
///
/// # impl Model {
/// # fn fit(&mut self, epochs: u64) -> Result<(), String> {
/// # //...
/// # self.fit_score += epochs as f64;
/// # Ok(())
/// # }
/// # }
///
/// impl GeneticNode for Node {
/// # fn initialize() -> Result<Box<Self>, String> {
/// # Ok(Box::new(Node {
/// # models: vec![
/// # Model { fit_score: 0.0 },
/// # Model { fit_score: 1.0 },
/// # Model { fit_score: 2.0 },
/// # Model { fit_score: 3.0 },
/// # Model { fit_score: 4.0 },
/// # ],
/// # population_size: 5,
/// # }))
/// # }
/// #
/// # //...
/// #
/// # fn simulate(&mut self, iterations: u64) -> Result<(), String> {
/// # for m in self.models.iter_mut() {
/// # m.fit(iterations)?;
/// # }
/// # Ok(())
/// # }
/// #
/// //...
///
/// # fn get_fit_score(&self) -> f64 {
/// # self.models
/// # .iter()
/// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap())
/// # .unwrap()
/// # .fit_score
/// # }
/// #
/// fn calculate_scores_and_trim(&mut self) -> Result<(), String> {
/// self.models.sort_by(|a, b| a.fit_score.partial_cmp(&b.fit_score).unwrap().reverse());
/// self.models.truncate(3);
/// Ok(())
/// }
///
/// //...
/// #
/// # fn mutate(&mut self) -> Result<(), String> {
/// # Ok(())
/// # }
/// }
///
/// # fn main() -> Result<(), String> {
/// let mut node = Node::initialize()?;
/// assert_eq!(node.models.len(), 5);
///
/// node.simulate(5)?;
/// node.calculate_scores_and_trim()?;
/// assert_eq!(node.models.len(), 3);
///
/// # assert_eq!(node.get_fit_score(), 9.0);
/// # Ok(())
/// # }
/// ```
fn calculate_scores_and_trim(&mut self) -> Result<(), String>; fn calculate_scores_and_trim(&mut self) -> Result<(), String>;
/// Mutates members in a population and/or crossbreeds them to produce new offspring. /// Mutates members in a population and/or crossbreeds them to produce new offspring.
///
/// # Examples
/// ```
/// # use gemla::bracket::genetic_node::GeneticNode;
/// # use std::convert::TryInto;
/// #
/// struct Model {
/// pub fit_score: f64,
/// //...
/// }
///
/// struct Node {
/// pub models: Vec<Model>,
/// population_size: i64,
/// //...
/// }
///
/// # impl Model {
/// # fn fit(&mut self, epochs: u64) -> Result<(), String> {
/// # //...
/// # self.fit_score += epochs as f64;
/// # Ok(())
/// # }
/// # }
///
/// fn mutate_random_individuals(_models: &Vec<Model>) -> Model
/// {
/// //...
/// # Model {
/// # fit_score: 0.0
/// # }
/// }
///
/// impl GeneticNode for Node {
/// # fn initialize() -> Result<Box<Self>, String> {
/// # Ok(Box::new(Node {
/// # models: vec![
/// # Model { fit_score: 0.0 },
/// # Model { fit_score: 1.0 },
/// # Model { fit_score: 2.0 },
/// # Model { fit_score: 3.0 },
/// # Model { fit_score: 4.0 },
/// # ],
/// # population_size: 5,
/// # }))
/// # }
/// #
/// # fn simulate(&mut self, iterations: u64) -> Result<(), String> {
/// # for m in self.models.iter_mut() {
/// # m.fit(iterations)?;
/// # }
/// # Ok(())
/// # }
/// #
/// # fn get_fit_score(&self) -> f64 {
/// # self.models
/// # .iter()
/// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap())
/// # .unwrap()
/// # .fit_score
/// # }
/// #
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), String> {
/// # self.models.sort_by(|a, b| a.fit_score.partial_cmp(&b.fit_score).unwrap().reverse());
/// # self.models.truncate(3);
/// # Ok(())
/// # }
/// //...
///
/// fn mutate(&mut self) -> Result<(), String> {
/// loop {
/// if self.models.len() < self.population_size.try_into().unwrap()
/// {
/// self.models.push(mutate_random_individuals(&self.models))
/// }
/// else{
/// return Ok(());
/// }
/// }
/// }
/// }
///
/// # fn main() -> Result<(), String> {
/// let mut node = Node::initialize()?;
/// assert_eq!(node.models.len(), 5);
///
/// node.simulate(5)?;
/// node.calculate_scores_and_trim()?;
/// assert_eq!(node.models.len(), 3);
///
/// node.mutate()?;
/// assert_eq!(node.models.len(), 5);
///
/// # assert_eq!(node.get_fit_score(), 9.0);
/// # Ok(())
/// # }
/// ```
fn mutate(&mut self) -> Result<(), String>; fn mutate(&mut self) -> Result<(), String>;
} }
@ -191,8 +387,8 @@ where
T: GeneticNode + fmt::Debug, T: GeneticNode + fmt::Debug,
{ {
/// Initializes a wrapper around a GeneticNode. If the initialization is successful the internal state will be changed to /// Initializes a wrapper around a GeneticNode. If the initialization is successful the internal state will be changed to
/// `GeneticState::Simulate` otherwise it will remain as `GeneticState::Initialize` and will attempt to be created in /// `GeneticState::Simulate` otherwise it will remain as `GeneticState::Initialize` and will attempt to be created in
/// [`process_node`](#method.process_node). /// [`process_node`](#method.process_node).
/// ///
/// # Examples /// # Examples
/// ``` /// ```
@ -252,7 +448,7 @@ where
/// Will loop through the node training and scoring process for the given number of `iterations`. /// Will loop through the node training and scoring process for the given number of `iterations`.
/// ///
/// ## Transitions /// ## Transitions
/// - `GeneticState::Initialize`: will attempt to call [`initialize`] on the node. When done successfully will change /// - `GeneticState::Initialize`: will attempt to call [`initialize`] on the node. When done successfully will change
/// the state to `GeneticState::Simulate` /// the state to `GeneticState::Simulate`
/// - `GeneticState::Simulate`: Will call [`simulate`] with a number of iterations (not for `iterations`). Will change the state to `GeneticState::Score` /// - `GeneticState::Simulate`: Will call [`simulate`] with a number of iterations (not for `iterations`). Will change the state to `GeneticState::Score`
/// - `GeneticState::Score`: Will call [`calculate_scores_and_trim`] and when the number of `iterations` have been reached will change /// - `GeneticState::Score`: Will call [`calculate_scores_and_trim`] and when the number of `iterations` have been reached will change

View file

@ -1,11 +1,11 @@
//! An enum used to control the state of a [`GeneticNode`] //! An enum used to control the state of a [`GeneticNode`]
//! //!
//! [`GeneticNode`]: crate::bracket::genetic_node //! [`GeneticNode`]: crate::bracket::genetic_node
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
/// An enum used to control the state of a [`GeneticNode`] /// An enum used to control the state of a [`GeneticNode`]
/// ///
/// [`GeneticNode`]: crate::bracket::genetic_node /// [`GeneticNode`]: crate::bracket::genetic_node
#[derive(Clone, Debug, Serialize, Deserialize, Copy)] #[derive(Clone, Debug, Serialize, Deserialize, Copy)]
#[serde(tag = "enumType", content = "enumContent")] #[serde(tag = "enumType", content = "enumContent")]

View file

@ -124,10 +124,9 @@ where
} }
#[cfg(test)] #[cfg(test)]
mod tests mod tests {
{
use super::*; use super::*;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::fmt; use std::fmt;
use std::str::FromStr; use std::str::FromStr;
@ -237,5 +236,4 @@ mod tests
std::fs::remove_file("./temp2").expect("Unable to remove file"); std::fs::remove_file("./temp2").expect("Unable to remove file");
} }
}
}

View file

@ -141,5 +141,4 @@ mod tests {
Ok(()) Ok(())
} }
} }

View file

@ -4,4 +4,4 @@ extern crate regex;
pub mod tree; pub mod tree;
pub mod bracket; pub mod bracket;
pub mod constants; pub mod constants;
pub mod file_linked; pub mod file_linked;

View file

@ -23,7 +23,6 @@
//! //# } //! //# }
//! ``` //! ```
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::fmt; use std::fmt;
@ -136,10 +135,8 @@ where
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests mod tests {
{
use super::*; use super::*;
#[test] #[test]
@ -176,4 +173,4 @@ mod tests
); );
assert_eq!(Tree::<i32>::fmt_node(&None), "_"); assert_eq!(Tree::<i32>::fmt_node(&None), "_");
} }
} }