Progress commenting bracket file

This commit is contained in:
vandomej 2021-08-23 10:04:11 -07:00
parent f6de0191cd
commit e47765095a
4 changed files with 211 additions and 41 deletions

View file

@ -2,11 +2,27 @@
//! //!
//! [`Bracket`]: crate::bracket::Bracket //! [`Bracket`]: crate::bracket::Bracket
use super::genetic_state::GeneticState;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::fmt; use std::fmt;
/// An enum used to control the state of a [`GeneticNode`]
///
/// [`GeneticNode`]: crate::bracket::genetic_node
#[derive(Clone, Debug, Serialize, Deserialize, Copy)]
#[serde(tag = "enumType", content = "enumContent")]
pub enum GeneticState {
/// The node and it's data have not finished initializing
Initialize,
/// The node is currently simulating a round against target data to determine the fitness of the population
Simulate,
/// The node is currently selecting members of the population that scored well and reducing the total population size
Score,
/// The node is currently mutating members of it's population and breeding new members
Mutate,
/// The node has finished processing for a given number of iterations
Finish,
}
/// A trait used to interact with the internal state of nodes within the [`Bracket`] /// A trait used to interact with the internal state of nodes within the [`Bracket`]
/// ///
/// [`Bracket`]: crate::bracket::Bracket /// [`Bracket`]: crate::bracket::Bracket

View file

@ -1,23 +0,0 @@
//! An enum used to control the state of a [`GeneticNode`]
//!
//! [`GeneticNode`]: crate::bracket::genetic_node
use serde::{Deserialize, Serialize};
/// An enum used to control the state of a [`GeneticNode`]
///
/// [`GeneticNode`]: crate::bracket::genetic_node
#[derive(Clone, Debug, Serialize, Deserialize, Copy)]
#[serde(tag = "enumType", content = "enumContent")]
pub enum GeneticState {
/// The node and it's data have not finished initializing
Initialize,
/// The node is currently simulating a round against target data to determine the fitness of the population
Simulate,
/// The node is currently selecting members of the population that scored well and reducing the total population size
Score,
/// The node is currently mutating members of it's population and breeding new members
Mutate,
/// The node has finished processing for a given number of iterations
Finish,
}

View file

@ -1,5 +1,7 @@
//! Simulates a genetic algorithm on a population in order to improve the fit score and performance. The simulations
//! are performed in a tournament bracket configuration so that populations can compete against each other.
pub mod genetic_node; pub mod genetic_node;
pub mod genetic_state;
use super::file_linked::FileLinked; use super::file_linked::FileLinked;
use super::tree; use super::tree;
@ -10,15 +12,91 @@ use std::fmt;
use std::str::FromStr; use std::str::FromStr;
use std::string::ToString; use std::string::ToString;
/// As the bracket tree increases in height, `IterationScaling` can be used to configure the number of iterations that
/// a node runs for.
///
/// # Examples
///
/// ```
/// # use gemla::bracket::*;
/// # use serde::{Deserialize, Serialize};
/// # use std::fmt;
/// # use std::str::FromStr;
/// # use std::string::ToString;
/// #
/// # #[derive(Default, Deserialize, Serialize, Clone)]
/// # struct TestState {
/// # pub score: f64,
/// # }
/// #
/// # impl FromStr for TestState {
/// # type Err = String;
/// #
/// # fn from_str(s: &str) -> Result<TestState, Self::Err> {
/// # toml::from_str(s).map_err(|_| format!("Unable to parse string {}", s))
/// # }
/// # }
/// #
/// # impl fmt::Display for TestState {
/// # fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
/// # write!(f, "{}", self.score)
/// # }
/// # }
/// #
/// # impl TestState {
/// # fn new(score: f64) -> TestState {
/// # TestState { score: score }
/// # }
/// # }
/// #
/// # impl genetic_node::GeneticNode for TestState {
/// # fn simulate(&mut self, iterations: u64) -> Result<(), String> {
/// # self.score += iterations as f64;
/// # Ok(())
/// # }
/// #
/// # fn get_fit_score(&self) -> f64 {
/// # self.score
/// # }
/// #
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), String> {
/// # Ok(())
/// # }
/// #
/// # fn mutate(&mut self) -> Result<(), String> {
/// # Ok(())
/// # }
/// #
/// # fn initialize() -> Result<Box<Self>, String> {
/// # Ok(Box::new(TestState { score: 0.0 }))
/// # }
/// # }
/// #
/// # fn main() {
/// let mut bracket = Bracket::<TestState>::initialize("./temp".to_string())
/// .expect("Bracket failed to initialize");
///
/// // Constant iteration scaling ensures that every node is simulated 5 times.
/// bracket
/// .mutate(|b| drop(b.iteration_scaling(IterationScaling::Constant(5))))
/// .expect("Failed to set iteration scaling");
///
/// # std::fs::remove_file("./temp").expect("Unable to remove file");
/// # }
/// ```
#[derive(Clone, Debug, Serialize, Deserialize, Copy)] #[derive(Clone, Debug, Serialize, Deserialize, Copy)]
#[serde(tag = "enumType", content = "enumContent")] #[serde(tag = "enumType", content = "enumContent")]
pub enum IterationScaling { pub enum IterationScaling {
Linear(u32), /// Scales the number of simulations linearly with the height of the bracket tree given by `f(x) = mx` where
/// x is the height and m is the linear constant provided.
Linear(u64),
/// Each node in a bracket is simulated the same number of times.
Constant(u64),
} }
impl Default for IterationScaling { impl Default for IterationScaling {
fn default() -> Self { fn default() -> Self {
IterationScaling::Linear(1) IterationScaling::Constant(1)
} }
} }
@ -32,14 +110,25 @@ impl fmt::Display for IterationScaling {
} }
} }
/// Creates a tournament style bracket for simulating and evaluating nodes of type `T` implementing [`GeneticNode`].
/// These nodes are built upwards as a balanced binary tree starting from the bottom. This results in `Bracket` building
/// a separate tree of the same height then merging trees together. Evaluating populations between nodes and taking the strongest
/// individuals.
///
/// [`GeneticNode`]: genetic_node::GeneticNode
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Bracket<T> { pub struct Bracket<T>
where
T: genetic_node::GeneticNode,
{
tree: tree::Tree<T>, tree: tree::Tree<T>,
step: u64,
iteration_scaling: IterationScaling, iteration_scaling: IterationScaling,
} }
impl<T: fmt::Display + Serialize> fmt::Display for Bracket<T> { impl<T: fmt::Display + Serialize> fmt::Display for Bracket<T>
where
T: genetic_node::GeneticNode,
{
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!( write!(
f, f,
@ -60,11 +149,81 @@ where
+ Serialize + Serialize
+ Clone, + Clone,
{ {
/// Initializes a bracket of type `T` storing the contents to `file_path`
///
/// # Examples
/// ```
/// # use gemla::bracket::*;
/// # use serde::{Deserialize, Serialize};
/// # use std::fmt;
/// # use std::str::FromStr;
/// # use std::string::ToString;
/// #
/// #[derive(Default, Deserialize, Serialize, Clone)]
/// struct TestState {
/// pub score: f64,
/// }
///
/// impl FromStr for TestState {
/// type Err = String;
///
/// fn from_str(s: &str) -> Result<TestState, Self::Err> {
/// toml::from_str(s).map_err(|_| format!("Unable to parse string {}", s))
/// }
/// }
/// #
/// # impl fmt::Display for TestState {
/// # fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
/// # write!(f, "{}", self.score)
/// # }
/// # }
/// #
/// impl TestState {
/// fn new(score: f64) -> TestState {
/// TestState { score: score }
/// }
/// }
///
/// impl genetic_node::GeneticNode for TestState {
/// # fn simulate(&mut self, iterations: u64) -> Result<(), String> {
/// # self.score += iterations as f64;
/// # Ok(())
/// # }
/// #
/// # fn get_fit_score(&self) -> f64 {
/// # self.score
/// # }
/// #
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), String> {
/// # Ok(())
/// # }
/// #
/// # fn mutate(&mut self) -> Result<(), String> {
/// # Ok(())
/// # }
/// #
/// fn initialize() -> Result<Box<Self>, String> {
/// Ok(Box::new(TestState { score: 0.0 }))
/// }
/// }
///
/// # fn main() {
/// let mut bracket = Bracket::<TestState>::initialize("./temp".to_string())
/// .expect("Bracket failed to initialize");
///
/// assert_eq!(
/// format!("{}", bracket),
/// format!("{{\"tree\":{},\"iteration_scaling\":{{\"enumType\":\"Constant\",\"enumContent\":1}}}}",
/// btree!(TestState::new(0.0)))
/// );
///
/// std::fs::remove_file("./temp").expect("Unable to remove file");
/// # }
/// ```
pub fn initialize(file_path: String) -> Result<FileLinked<Self>, String> { pub fn initialize(file_path: String) -> Result<FileLinked<Self>, String> {
FileLinked::new( FileLinked::new(
Bracket { Bracket {
tree: btree!(*T::initialize()?), tree: btree!(*T::initialize()?),
step: 0,
iteration_scaling: IterationScaling::default(), iteration_scaling: IterationScaling::default(),
}, },
file_path, file_path,
@ -81,7 +240,8 @@ where
let mut base_node = btree!(*T::initialize()?); let mut base_node = btree!(*T::initialize()?);
base_node.val.simulate(match self.iteration_scaling { base_node.val.simulate(match self.iteration_scaling {
IterationScaling::Linear(x) => (x as u64) * height, IterationScaling::Linear(x) => x * height,
IterationScaling::Constant(x) => x,
})?; })?;
Ok(btree!(base_node.val)) Ok(btree!(base_node.val))
@ -95,7 +255,8 @@ where
}; };
new_val.simulate(match self.iteration_scaling { new_val.simulate(match self.iteration_scaling {
IterationScaling::Linear(x) => (x as u64) * height, IterationScaling::Linear(x) => x * height,
IterationScaling::Constant(x) => x,
})?; })?;
Ok(btree!(new_val, left, right)) Ok(btree!(new_val, left, right))
@ -103,10 +264,11 @@ where
} }
pub fn run_simulation_step(&mut self) -> Result<&mut Self, String> { pub fn run_simulation_step(&mut self) -> Result<&mut Self, String> {
let new_branch = self.create_new_branch(self.step + 1)?; let new_branch = self.create_new_branch(self.tree.height())?;
self.tree.val.simulate(match self.iteration_scaling { self.tree.val.simulate(match self.iteration_scaling {
IterationScaling::Linear(x) => ((x as u64) * (self.step + 1)), IterationScaling::Linear(x) => (x * self.tree.height()),
IterationScaling::Constant(x) => x,
})?; })?;
let new_val = if new_branch.val.get_fit_score() >= self.tree.val.get_fit_score() { let new_val = if new_branch.val.get_fit_score() >= self.tree.val.get_fit_score() {
@ -117,8 +279,6 @@ where
self.tree = btree!(new_val, new_branch, self.tree.clone()); self.tree = btree!(new_val, new_branch, self.tree.clone());
self.step += 1;
Ok(self) Ok(self)
} }
} }
@ -126,7 +286,7 @@ where
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::fmt; use std::fmt;
use std::str::FromStr; use std::str::FromStr;
@ -187,7 +347,7 @@ mod tests {
assert_eq!( assert_eq!(
format!("{}", bracket), format!("{}", bracket),
format!("{{\"tree\":{},\"step\":0,\"iteration_scaling\":{{\"enumType\":\"Linear\",\"enumContent\":1}}}}", format!("{{\"tree\":{},\"iteration_scaling\":{{\"enumType\":\"Constant\",\"enumContent\":1}}}}",
btree!(TestState::new(0.0))) btree!(TestState::new(0.0)))
); );
@ -210,7 +370,7 @@ mod tests {
assert_eq!( assert_eq!(
format!("{}", bracket), format!("{}", bracket),
format!("{{\"tree\":{},\"step\":3,\"iteration_scaling\":{{\"enumType\":\"Linear\",\"enumContent\":2}}}}", format!("{{\"tree\":{},\"iteration_scaling\":{{\"enumType\":\"Linear\",\"enumContent\":2}}}}",
btree!( btree!(
TestState::new(12.0), TestState::new(12.0),
btree!( btree!(

View file

@ -25,6 +25,7 @@
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::cmp::max;
use std::fmt; use std::fmt;
use std::str::FromStr; use std::str::FromStr;
@ -102,6 +103,15 @@ impl<T> Tree<T> {
Tree { val, left, right } Tree { val, left, right }
} }
pub fn height(&self) -> u64 {
match (self.left.as_ref(), self.right.as_ref()) {
(Some(l), Some(r)) => max(l.height(), r.height()) + 1,
(Some(l), None) => l.height() + 1,
(None, Some(r)) => r.height() + 1,
_ => 1,
}
}
pub fn fmt_node(t: &Option<Box<Tree<T>>>) -> String pub fn fmt_node(t: &Option<Box<Tree<T>>>) -> String
where where
T: fmt::Display, T: fmt::Display,
@ -163,6 +173,13 @@ mod tests {
); );
} }
#[test]
fn test_height() {
assert_eq!(1, btree!(1).height());
assert_eq!(3, btree!(1, btree!(2), btree!(2, btree!(3),)).height());
}
#[test] #[test]
fn test_fmt_node() { fn test_fmt_node() {
let t = btree!(17, btree!(16), btree!(12)); let t = btree!(17, btree!(16), btree!(12));