August 20th 2021, 2:23 am
This commit is contained in:
parent
61786ab303
commit
f6de0191cd
6 changed files with 276 additions and 86 deletions
|
@ -1,5 +1,5 @@
|
|||
//! A trait used to interact with the internal state of nodes within the [`Bracket`]
|
||||
//!
|
||||
//!
|
||||
//! [`Bracket`]: crate::bracket::Bracket
|
||||
|
||||
use super::genetic_state::GeneticState;
|
||||
|
@ -12,21 +12,21 @@ use std::fmt;
|
|||
/// [`Bracket`]: crate::bracket::Bracket
|
||||
pub trait GeneticNode {
|
||||
/// Initializes a new instance of a [`GeneticState`].
|
||||
///
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
///
|
||||
/// ```
|
||||
/// # use gemla::bracket::genetic_node::GeneticNode;
|
||||
/// #
|
||||
/// struct Node {
|
||||
/// pub fit_score: f64,
|
||||
/// }
|
||||
///
|
||||
///
|
||||
/// impl GeneticNode for Node {
|
||||
/// fn initialize() -> Result<Box<Self>, String> {
|
||||
/// Ok(Box::new(Node {fit_score: 0.0}))
|
||||
/// }
|
||||
///
|
||||
///
|
||||
/// //...
|
||||
/// #
|
||||
/// # fn simulate(&mut self, iterations: u64) -> Result<(), String> {
|
||||
|
@ -45,7 +45,7 @@ pub trait GeneticNode {
|
|||
/// # Ok(())
|
||||
/// # }
|
||||
/// }
|
||||
///
|
||||
///
|
||||
/// # fn main() -> Result<(), String> {
|
||||
/// let node = Node::initialize()?;
|
||||
/// assert_eq!(node.get_fit_score(), 0.0);
|
||||
|
@ -66,62 +66,63 @@ pub trait GeneticNode {
|
|||
/// pub fit_score: f64,
|
||||
/// //...
|
||||
/// }
|
||||
///
|
||||
///
|
||||
/// struct Node {
|
||||
/// pub model: Vec<Model>,
|
||||
/// pub models: Vec<Model>,
|
||||
/// //...
|
||||
/// }
|
||||
///
|
||||
/// impl Model {
|
||||
/// fn fit(&mut self, epochs: u64) -> Result<(), String> {
|
||||
/// //...
|
||||
/// # self.fit_score += epochs as f64;
|
||||
/// # Ok(())
|
||||
/// # self.fit_score += epochs as f64;
|
||||
/// # Ok(())
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
///
|
||||
/// impl GeneticNode for Node {
|
||||
/// # fn initialize() -> Result<Box<Self>, String> {
|
||||
/// # Ok(Box::new(Node {fit_score: 0.0, model: Model {fit_score: 0.0}}))
|
||||
/// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]}))
|
||||
/// # }
|
||||
/// #
|
||||
/// //...
|
||||
///
|
||||
/// fn simulate(&mut self, iterations: u64) -> Result<(), String> {
|
||||
/// self.model.fit(iterations)?;
|
||||
/// self.fit_score = self.model.fit_score;
|
||||
/// Ok(())
|
||||
/// }
|
||||
///
|
||||
/// fn simulate(&mut self, iterations: u64) -> Result<(), String> {
|
||||
/// for m in self.models.iter_mut()
|
||||
/// {
|
||||
/// m.fit(iterations)?;
|
||||
/// }
|
||||
/// Ok(())
|
||||
/// }
|
||||
///
|
||||
/// //...
|
||||
/// #
|
||||
/// # fn get_fit_score(&self) -> f64 {
|
||||
/// # self.fit_score
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), String> {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn mutate(&mut self) -> Result<(), String> {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
///
|
||||
/// # fn get_fit_score(&self) -> f64 {
|
||||
/// # self.models.iter().max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap()).unwrap().fit_score
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), String> {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn mutate(&mut self) -> Result<(), String> {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// }
|
||||
///
|
||||
///
|
||||
/// # fn main() -> Result<(), String> {
|
||||
/// let mut node = Node::initialize()?;
|
||||
/// (*node).simulate(5)?;
|
||||
/// # assert_eq!(node.get_fit_score(), 5.0);
|
||||
/// # Ok(())
|
||||
/// node.simulate(5)?;
|
||||
/// assert_eq!(node.get_fit_score(), 5.0);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
/// ```
|
||||
fn simulate(&mut self, iterations: u64) -> Result<(), String>;
|
||||
|
||||
/// Returns a fit score associated with the nodes performance.
|
||||
/// This will be used by a bracket in order to determine the most successful child.
|
||||
///
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// # use gemla::bracket::genetic_node::GeneticNode;
|
||||
/// #
|
||||
|
@ -130,47 +131,242 @@ pub trait GeneticNode {
|
|||
/// //...
|
||||
/// }
|
||||
///
|
||||
/// impl GeneticNode for Model {
|
||||
/// # fn initialize() -> Result<Box<Self>, String> {
|
||||
/// # Ok(Box::new(Model {fit_score: 0.0, model: Model {fit_score: 0.0}}))
|
||||
/// # }
|
||||
///
|
||||
/// struct Node {
|
||||
/// pub models: Vec<Model>,
|
||||
/// //...
|
||||
///
|
||||
/// # fn simulate(&mut self, iterations: u64) -> Result<(), String> {
|
||||
/// # self.model.fit(iterations)?;
|
||||
/// # self.fit_score = self.model.fit_score;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// #
|
||||
/// #
|
||||
/// fn get_fit_score(&self) -> f64 {
|
||||
/// self.fit_score
|
||||
/// }
|
||||
///
|
||||
/// //...
|
||||
/// #
|
||||
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), String> {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn mutate(&mut self) -> Result<(), String> {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// }
|
||||
///
|
||||
/// # fn main() -> Result<(), String> {
|
||||
/// let mut model = Model::initialize()?;
|
||||
/// assert_eq!(node.get_fit_score(), 0.0);
|
||||
/// # Ok(())
|
||||
///
|
||||
/// # impl Model {
|
||||
/// # fn fit(&mut self, epochs: u64) -> Result<(), String> {
|
||||
/// # //...
|
||||
/// # self.fit_score += epochs as f64;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
/// impl GeneticNode for Node {
|
||||
/// # fn initialize() -> Result<Box<Self>, String> {
|
||||
/// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]}))
|
||||
/// # }
|
||||
/// #
|
||||
/// # //...
|
||||
/// #
|
||||
/// # fn simulate(&mut self, iterations: u64) -> Result<(), String> {
|
||||
/// # for m in self.models.iter_mut()
|
||||
/// # {
|
||||
/// # m.fit(iterations)?;
|
||||
/// # }
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// #
|
||||
/// //...
|
||||
///
|
||||
/// fn get_fit_score(&self) -> f64 {
|
||||
/// self.models.iter().max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap()).unwrap().fit_score
|
||||
/// }
|
||||
///
|
||||
/// //...
|
||||
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), String> {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn mutate(&mut self) -> Result<(), String> {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// }
|
||||
///
|
||||
/// # fn main() -> Result<(), String> {
|
||||
/// let mut node = Node::initialize()?;
|
||||
/// node.simulate(5)?;
|
||||
/// assert_eq!(node.get_fit_score(), 5.0);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
fn get_fit_score(&self) -> f64;
|
||||
|
||||
/// Used when scoring the nodes after simulating and should remove underperforming children.
|
||||
///
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// # use gemla::bracket::genetic_node::GeneticNode;
|
||||
/// #
|
||||
/// struct Model {
|
||||
/// pub fit_score: f64,
|
||||
/// //...
|
||||
/// }
|
||||
///
|
||||
/// struct Node {
|
||||
/// pub models: Vec<Model>,
|
||||
/// population_size: i64,
|
||||
/// //...
|
||||
/// }
|
||||
///
|
||||
/// # impl Model {
|
||||
/// # fn fit(&mut self, epochs: u64) -> Result<(), String> {
|
||||
/// # //...
|
||||
/// # self.fit_score += epochs as f64;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// # }
|
||||
///
|
||||
/// impl GeneticNode for Node {
|
||||
/// # fn initialize() -> Result<Box<Self>, String> {
|
||||
/// # Ok(Box::new(Node {
|
||||
/// # models: vec![
|
||||
/// # Model { fit_score: 0.0 },
|
||||
/// # Model { fit_score: 1.0 },
|
||||
/// # Model { fit_score: 2.0 },
|
||||
/// # Model { fit_score: 3.0 },
|
||||
/// # Model { fit_score: 4.0 },
|
||||
/// # ],
|
||||
/// # population_size: 5,
|
||||
/// # }))
|
||||
/// # }
|
||||
/// #
|
||||
/// # //...
|
||||
/// #
|
||||
/// # fn simulate(&mut self, iterations: u64) -> Result<(), String> {
|
||||
/// # for m in self.models.iter_mut() {
|
||||
/// # m.fit(iterations)?;
|
||||
/// # }
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// #
|
||||
/// //...
|
||||
///
|
||||
/// # fn get_fit_score(&self) -> f64 {
|
||||
/// # self.models
|
||||
/// # .iter()
|
||||
/// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap())
|
||||
/// # .unwrap()
|
||||
/// # .fit_score
|
||||
/// # }
|
||||
/// #
|
||||
/// fn calculate_scores_and_trim(&mut self) -> Result<(), String> {
|
||||
/// self.models.sort_by(|a, b| a.fit_score.partial_cmp(&b.fit_score).unwrap().reverse());
|
||||
/// self.models.truncate(3);
|
||||
/// Ok(())
|
||||
/// }
|
||||
///
|
||||
/// //...
|
||||
/// #
|
||||
/// # fn mutate(&mut self) -> Result<(), String> {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// }
|
||||
///
|
||||
/// # fn main() -> Result<(), String> {
|
||||
/// let mut node = Node::initialize()?;
|
||||
/// assert_eq!(node.models.len(), 5);
|
||||
///
|
||||
/// node.simulate(5)?;
|
||||
/// node.calculate_scores_and_trim()?;
|
||||
/// assert_eq!(node.models.len(), 3);
|
||||
///
|
||||
/// # assert_eq!(node.get_fit_score(), 9.0);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
fn calculate_scores_and_trim(&mut self) -> Result<(), String>;
|
||||
|
||||
/// Mutates members in a population and/or crossbreeds them to produce new offspring.
|
||||
///
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// # use gemla::bracket::genetic_node::GeneticNode;
|
||||
/// # use std::convert::TryInto;
|
||||
/// #
|
||||
/// struct Model {
|
||||
/// pub fit_score: f64,
|
||||
/// //...
|
||||
/// }
|
||||
///
|
||||
/// struct Node {
|
||||
/// pub models: Vec<Model>,
|
||||
/// population_size: i64,
|
||||
/// //...
|
||||
/// }
|
||||
///
|
||||
/// # impl Model {
|
||||
/// # fn fit(&mut self, epochs: u64) -> Result<(), String> {
|
||||
/// # //...
|
||||
/// # self.fit_score += epochs as f64;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// # }
|
||||
///
|
||||
/// fn mutate_random_individuals(_models: &Vec<Model>) -> Model
|
||||
/// {
|
||||
/// //...
|
||||
/// # Model {
|
||||
/// # fit_score: 0.0
|
||||
/// # }
|
||||
/// }
|
||||
///
|
||||
/// impl GeneticNode for Node {
|
||||
/// # fn initialize() -> Result<Box<Self>, String> {
|
||||
/// # Ok(Box::new(Node {
|
||||
/// # models: vec![
|
||||
/// # Model { fit_score: 0.0 },
|
||||
/// # Model { fit_score: 1.0 },
|
||||
/// # Model { fit_score: 2.0 },
|
||||
/// # Model { fit_score: 3.0 },
|
||||
/// # Model { fit_score: 4.0 },
|
||||
/// # ],
|
||||
/// # population_size: 5,
|
||||
/// # }))
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn simulate(&mut self, iterations: u64) -> Result<(), String> {
|
||||
/// # for m in self.models.iter_mut() {
|
||||
/// # m.fit(iterations)?;
|
||||
/// # }
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn get_fit_score(&self) -> f64 {
|
||||
/// # self.models
|
||||
/// # .iter()
|
||||
/// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap())
|
||||
/// # .unwrap()
|
||||
/// # .fit_score
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), String> {
|
||||
/// # self.models.sort_by(|a, b| a.fit_score.partial_cmp(&b.fit_score).unwrap().reverse());
|
||||
/// # self.models.truncate(3);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// //...
|
||||
///
|
||||
/// fn mutate(&mut self) -> Result<(), String> {
|
||||
/// loop {
|
||||
/// if self.models.len() < self.population_size.try_into().unwrap()
|
||||
/// {
|
||||
/// self.models.push(mutate_random_individuals(&self.models))
|
||||
/// }
|
||||
/// else{
|
||||
/// return Ok(());
|
||||
/// }
|
||||
/// }
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// # fn main() -> Result<(), String> {
|
||||
/// let mut node = Node::initialize()?;
|
||||
/// assert_eq!(node.models.len(), 5);
|
||||
///
|
||||
/// node.simulate(5)?;
|
||||
/// node.calculate_scores_and_trim()?;
|
||||
/// assert_eq!(node.models.len(), 3);
|
||||
///
|
||||
/// node.mutate()?;
|
||||
/// assert_eq!(node.models.len(), 5);
|
||||
///
|
||||
/// # assert_eq!(node.get_fit_score(), 9.0);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
fn mutate(&mut self) -> Result<(), String>;
|
||||
}
|
||||
|
||||
|
@ -191,8 +387,8 @@ where
|
|||
T: GeneticNode + fmt::Debug,
|
||||
{
|
||||
/// Initializes a wrapper around a GeneticNode. If the initialization is successful the internal state will be changed to
|
||||
/// `GeneticState::Simulate` otherwise it will remain as `GeneticState::Initialize` and will attempt to be created in
|
||||
/// [`process_node`](#method.process_node).
|
||||
/// `GeneticState::Simulate` otherwise it will remain as `GeneticState::Initialize` and will attempt to be created in
|
||||
/// [`process_node`](#method.process_node).
|
||||
///
|
||||
/// # Examples
|
||||
/// ```
|
||||
|
@ -252,7 +448,7 @@ where
|
|||
/// Will loop through the node training and scoring process for the given number of `iterations`.
|
||||
///
|
||||
/// ## Transitions
|
||||
/// - `GeneticState::Initialize`: will attempt to call [`initialize`] on the node. When done successfully will change
|
||||
/// - `GeneticState::Initialize`: will attempt to call [`initialize`] on the node. When done successfully will change
|
||||
/// the state to `GeneticState::Simulate`
|
||||
/// - `GeneticState::Simulate`: Will call [`simulate`] with a number of iterations (not for `iterations`). Will change the state to `GeneticState::Score`
|
||||
/// - `GeneticState::Score`: Will call [`calculate_scores_and_trim`] and when the number of `iterations` have been reached will change
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
//! An enum used to control the state of a [`GeneticNode`]
|
||||
//!
|
||||
//!
|
||||
//! [`GeneticNode`]: crate::bracket::genetic_node
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// An enum used to control the state of a [`GeneticNode`]
|
||||
///
|
||||
///
|
||||
/// [`GeneticNode`]: crate::bracket::genetic_node
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Copy)]
|
||||
#[serde(tag = "enumType", content = "enumContent")]
|
||||
|
|
|
@ -124,10 +124,9 @@ where
|
|||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests
|
||||
{
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
use std::str::FromStr;
|
||||
|
@ -237,5 +236,4 @@ mod tests
|
|||
|
||||
std::fs::remove_file("./temp2").expect("Unable to remove file");
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -141,5 +141,4 @@ mod tests {
|
|||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -4,4 +4,4 @@ extern crate regex;
|
|||
pub mod tree;
|
||||
pub mod bracket;
|
||||
pub mod constants;
|
||||
pub mod file_linked;
|
||||
pub mod file_linked;
|
||||
|
|
|
@ -23,7 +23,6 @@
|
|||
//! //# }
|
||||
//! ```
|
||||
|
||||
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
|
@ -136,10 +135,8 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests
|
||||
{
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
|
@ -176,4 +173,4 @@ mod tests
|
|||
);
|
||||
assert_eq!(Tree::<i32>::fmt_node(&None), "_");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue