From e47765095af166c0fbcab42f0102d8957aaf1254 Mon Sep 17 00:00:00 2001 From: vandomej Date: Mon, 23 Aug 2021 10:04:11 -0700 Subject: [PATCH] Progress commenting bracket file --- gemla/src/bracket/genetic_node.rs | 20 ++- gemla/src/bracket/genetic_state.rs | 23 ---- gemla/src/bracket/mod.rs | 192 ++++++++++++++++++++++++++--- gemla/src/tree/mod.rs | 17 +++ 4 files changed, 211 insertions(+), 41 deletions(-) delete mode 100644 gemla/src/bracket/genetic_state.rs diff --git a/gemla/src/bracket/genetic_node.rs b/gemla/src/bracket/genetic_node.rs index 7eeb832..9caa3fe 100644 --- a/gemla/src/bracket/genetic_node.rs +++ b/gemla/src/bracket/genetic_node.rs @@ -2,11 +2,27 @@ //! //! [`Bracket`]: crate::bracket::Bracket -use super::genetic_state::GeneticState; - use serde::{Deserialize, Serialize}; use std::fmt; +/// An enum used to control the state of a [`GeneticNode`] +/// +/// [`GeneticNode`]: crate::bracket::genetic_node +#[derive(Clone, Debug, Serialize, Deserialize, Copy)] +#[serde(tag = "enumType", content = "enumContent")] +pub enum GeneticState { + /// The node and it's data have not finished initializing + Initialize, + /// The node is currently simulating a round against target data to determine the fitness of the population + Simulate, + /// The node is currently selecting members of the population that scored well and reducing the total population size + Score, + /// The node is currently mutating members of it's population and breeding new members + Mutate, + /// The node has finished processing for a given number of iterations + Finish, +} + /// A trait used to interact with the internal state of nodes within the [`Bracket`] /// /// [`Bracket`]: crate::bracket::Bracket diff --git a/gemla/src/bracket/genetic_state.rs b/gemla/src/bracket/genetic_state.rs deleted file mode 100644 index 5630a35..0000000 --- a/gemla/src/bracket/genetic_state.rs +++ /dev/null @@ -1,23 +0,0 @@ -//! An enum used to control the state of a [`GeneticNode`] -//! -//! [`GeneticNode`]: crate::bracket::genetic_node - -use serde::{Deserialize, Serialize}; - -/// An enum used to control the state of a [`GeneticNode`] -/// -/// [`GeneticNode`]: crate::bracket::genetic_node -#[derive(Clone, Debug, Serialize, Deserialize, Copy)] -#[serde(tag = "enumType", content = "enumContent")] -pub enum GeneticState { - /// The node and it's data have not finished initializing - Initialize, - /// The node is currently simulating a round against target data to determine the fitness of the population - Simulate, - /// The node is currently selecting members of the population that scored well and reducing the total population size - Score, - /// The node is currently mutating members of it's population and breeding new members - Mutate, - /// The node has finished processing for a given number of iterations - Finish, -} diff --git a/gemla/src/bracket/mod.rs b/gemla/src/bracket/mod.rs index 51c7a18..6180193 100644 --- a/gemla/src/bracket/mod.rs +++ b/gemla/src/bracket/mod.rs @@ -1,5 +1,7 @@ +//! Simulates a genetic algorithm on a population in order to improve the fit score and performance. The simulations +//! are performed in a tournament bracket configuration so that populations can compete against each other. + pub mod genetic_node; -pub mod genetic_state; use super::file_linked::FileLinked; use super::tree; @@ -10,15 +12,91 @@ use std::fmt; use std::str::FromStr; use std::string::ToString; +/// As the bracket tree increases in height, `IterationScaling` can be used to configure the number of iterations that +/// a node runs for. +/// +/// # Examples +/// +/// ``` +/// # use gemla::bracket::*; +/// # use serde::{Deserialize, Serialize}; +/// # use std::fmt; +/// # use std::str::FromStr; +/// # use std::string::ToString; +/// # +/// # #[derive(Default, Deserialize, Serialize, Clone)] +/// # struct TestState { +/// # pub score: f64, +/// # } +/// # +/// # impl FromStr for TestState { +/// # type Err = String; +/// # +/// # fn from_str(s: &str) -> Result { +/// # toml::from_str(s).map_err(|_| format!("Unable to parse string {}", s)) +/// # } +/// # } +/// # +/// # impl fmt::Display for TestState { +/// # fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { +/// # write!(f, "{}", self.score) +/// # } +/// # } +/// # +/// # impl TestState { +/// # fn new(score: f64) -> TestState { +/// # TestState { score: score } +/// # } +/// # } +/// # +/// # impl genetic_node::GeneticNode for TestState { +/// # fn simulate(&mut self, iterations: u64) -> Result<(), String> { +/// # self.score += iterations as f64; +/// # Ok(()) +/// # } +/// # +/// # fn get_fit_score(&self) -> f64 { +/// # self.score +/// # } +/// # +/// # fn calculate_scores_and_trim(&mut self) -> Result<(), String> { +/// # Ok(()) +/// # } +/// # +/// # fn mutate(&mut self) -> Result<(), String> { +/// # Ok(()) +/// # } +/// # +/// # fn initialize() -> Result, String> { +/// # Ok(Box::new(TestState { score: 0.0 })) +/// # } +/// # } +/// # +/// # fn main() { +/// let mut bracket = Bracket::::initialize("./temp".to_string()) +/// .expect("Bracket failed to initialize"); +/// +/// // Constant iteration scaling ensures that every node is simulated 5 times. +/// bracket +/// .mutate(|b| drop(b.iteration_scaling(IterationScaling::Constant(5)))) +/// .expect("Failed to set iteration scaling"); +/// +/// # std::fs::remove_file("./temp").expect("Unable to remove file"); +/// # } +/// ``` #[derive(Clone, Debug, Serialize, Deserialize, Copy)] #[serde(tag = "enumType", content = "enumContent")] pub enum IterationScaling { - Linear(u32), + /// Scales the number of simulations linearly with the height of the bracket tree given by `f(x) = mx` where + /// x is the height and m is the linear constant provided. + Linear(u64), + /// Each node in a bracket is simulated the same number of times. + Constant(u64), } impl Default for IterationScaling { fn default() -> Self { - IterationScaling::Linear(1) + IterationScaling::Constant(1) } } @@ -32,14 +110,25 @@ impl fmt::Display for IterationScaling { } } +/// Creates a tournament style bracket for simulating and evaluating nodes of type `T` implementing [`GeneticNode`]. +/// These nodes are built upwards as a balanced binary tree starting from the bottom. This results in `Bracket` building +/// a separate tree of the same height then merging trees together. Evaluating populations between nodes and taking the strongest +/// individuals. +/// +/// [`GeneticNode`]: genetic_node::GeneticNode #[derive(Serialize, Deserialize, Clone, Debug)] -pub struct Bracket { +pub struct Bracket +where + T: genetic_node::GeneticNode, +{ tree: tree::Tree, - step: u64, iteration_scaling: IterationScaling, } -impl fmt::Display for Bracket { +impl fmt::Display for Bracket +where + T: genetic_node::GeneticNode, +{ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, @@ -60,11 +149,81 @@ where + Serialize + Clone, { + /// Initializes a bracket of type `T` storing the contents to `file_path` + /// + /// # Examples + /// ``` + /// # use gemla::bracket::*; + /// # use serde::{Deserialize, Serialize}; + /// # use std::fmt; + /// # use std::str::FromStr; + /// # use std::string::ToString; + /// # + /// #[derive(Default, Deserialize, Serialize, Clone)] + /// struct TestState { + /// pub score: f64, + /// } + /// + /// impl FromStr for TestState { + /// type Err = String; + /// + /// fn from_str(s: &str) -> Result { + /// toml::from_str(s).map_err(|_| format!("Unable to parse string {}", s)) + /// } + /// } + /// # + /// # impl fmt::Display for TestState { + /// # fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + /// # write!(f, "{}", self.score) + /// # } + /// # } + /// # + /// impl TestState { + /// fn new(score: f64) -> TestState { + /// TestState { score: score } + /// } + /// } + /// + /// impl genetic_node::GeneticNode for TestState { + /// # fn simulate(&mut self, iterations: u64) -> Result<(), String> { + /// # self.score += iterations as f64; + /// # Ok(()) + /// # } + /// # + /// # fn get_fit_score(&self) -> f64 { + /// # self.score + /// # } + /// # + /// # fn calculate_scores_and_trim(&mut self) -> Result<(), String> { + /// # Ok(()) + /// # } + /// # + /// # fn mutate(&mut self) -> Result<(), String> { + /// # Ok(()) + /// # } + /// # + /// fn initialize() -> Result, String> { + /// Ok(Box::new(TestState { score: 0.0 })) + /// } + /// } + /// + /// # fn main() { + /// let mut bracket = Bracket::::initialize("./temp".to_string()) + /// .expect("Bracket failed to initialize"); + /// + /// assert_eq!( + /// format!("{}", bracket), + /// format!("{{\"tree\":{},\"iteration_scaling\":{{\"enumType\":\"Constant\",\"enumContent\":1}}}}", + /// btree!(TestState::new(0.0))) + /// ); + /// + /// std::fs::remove_file("./temp").expect("Unable to remove file"); + /// # } + /// ``` pub fn initialize(file_path: String) -> Result, String> { FileLinked::new( Bracket { tree: btree!(*T::initialize()?), - step: 0, iteration_scaling: IterationScaling::default(), }, file_path, @@ -81,7 +240,8 @@ where let mut base_node = btree!(*T::initialize()?); base_node.val.simulate(match self.iteration_scaling { - IterationScaling::Linear(x) => (x as u64) * height, + IterationScaling::Linear(x) => x * height, + IterationScaling::Constant(x) => x, })?; Ok(btree!(base_node.val)) @@ -95,7 +255,8 @@ where }; new_val.simulate(match self.iteration_scaling { - IterationScaling::Linear(x) => (x as u64) * height, + IterationScaling::Linear(x) => x * height, + IterationScaling::Constant(x) => x, })?; Ok(btree!(new_val, left, right)) @@ -103,10 +264,11 @@ where } pub fn run_simulation_step(&mut self) -> Result<&mut Self, String> { - let new_branch = self.create_new_branch(self.step + 1)?; + let new_branch = self.create_new_branch(self.tree.height())?; self.tree.val.simulate(match self.iteration_scaling { - IterationScaling::Linear(x) => ((x as u64) * (self.step + 1)), + IterationScaling::Linear(x) => (x * self.tree.height()), + IterationScaling::Constant(x) => x, })?; let new_val = if new_branch.val.get_fit_score() >= self.tree.val.get_fit_score() { @@ -117,8 +279,6 @@ where self.tree = btree!(new_val, new_branch, self.tree.clone()); - self.step += 1; - Ok(self) } } @@ -126,7 +286,7 @@ where #[cfg(test)] mod tests { use super::*; - + use serde::{Deserialize, Serialize}; use std::fmt; use std::str::FromStr; @@ -187,7 +347,7 @@ mod tests { assert_eq!( format!("{}", bracket), - format!("{{\"tree\":{},\"step\":0,\"iteration_scaling\":{{\"enumType\":\"Linear\",\"enumContent\":1}}}}", + format!("{{\"tree\":{},\"iteration_scaling\":{{\"enumType\":\"Constant\",\"enumContent\":1}}}}", btree!(TestState::new(0.0))) ); @@ -210,7 +370,7 @@ mod tests { assert_eq!( format!("{}", bracket), - format!("{{\"tree\":{},\"step\":3,\"iteration_scaling\":{{\"enumType\":\"Linear\",\"enumContent\":2}}}}", + format!("{{\"tree\":{},\"iteration_scaling\":{{\"enumType\":\"Linear\",\"enumContent\":2}}}}", btree!( TestState::new(12.0), btree!( diff --git a/gemla/src/tree/mod.rs b/gemla/src/tree/mod.rs index dc0b3db..b6b5db3 100644 --- a/gemla/src/tree/mod.rs +++ b/gemla/src/tree/mod.rs @@ -25,6 +25,7 @@ use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; +use std::cmp::max; use std::fmt; use std::str::FromStr; @@ -102,6 +103,15 @@ impl Tree { Tree { val, left, right } } + pub fn height(&self) -> u64 { + match (self.left.as_ref(), self.right.as_ref()) { + (Some(l), Some(r)) => max(l.height(), r.height()) + 1, + (Some(l), None) => l.height() + 1, + (None, Some(r)) => r.height() + 1, + _ => 1, + } + } + pub fn fmt_node(t: &Option>>) -> String where T: fmt::Display, @@ -163,6 +173,13 @@ mod tests { ); } + #[test] + fn test_height() { + assert_eq!(1, btree!(1).height()); + + assert_eq!(3, btree!(1, btree!(2), btree!(2, btree!(3),)).height()); + } + #[test] fn test_fmt_node() { let t = btree!(17, btree!(16), btree!(12));