Progress commenting bracket file
This commit is contained in:
parent
f6de0191cd
commit
e47765095a
4 changed files with 211 additions and 41 deletions
|
@ -2,11 +2,27 @@
|
|||
//!
|
||||
//! [`Bracket`]: crate::bracket::Bracket
|
||||
|
||||
use super::genetic_state::GeneticState;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
|
||||
/// An enum used to control the state of a [`GeneticNode`]
|
||||
///
|
||||
/// [`GeneticNode`]: crate::bracket::genetic_node
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Copy)]
|
||||
#[serde(tag = "enumType", content = "enumContent")]
|
||||
pub enum GeneticState {
|
||||
/// The node and it's data have not finished initializing
|
||||
Initialize,
|
||||
/// The node is currently simulating a round against target data to determine the fitness of the population
|
||||
Simulate,
|
||||
/// The node is currently selecting members of the population that scored well and reducing the total population size
|
||||
Score,
|
||||
/// The node is currently mutating members of it's population and breeding new members
|
||||
Mutate,
|
||||
/// The node has finished processing for a given number of iterations
|
||||
Finish,
|
||||
}
|
||||
|
||||
/// A trait used to interact with the internal state of nodes within the [`Bracket`]
|
||||
///
|
||||
/// [`Bracket`]: crate::bracket::Bracket
|
||||
|
|
|
@ -1,23 +0,0 @@
|
|||
//! An enum used to control the state of a [`GeneticNode`]
|
||||
//!
|
||||
//! [`GeneticNode`]: crate::bracket::genetic_node
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// An enum used to control the state of a [`GeneticNode`]
|
||||
///
|
||||
/// [`GeneticNode`]: crate::bracket::genetic_node
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Copy)]
|
||||
#[serde(tag = "enumType", content = "enumContent")]
|
||||
pub enum GeneticState {
|
||||
/// The node and it's data have not finished initializing
|
||||
Initialize,
|
||||
/// The node is currently simulating a round against target data to determine the fitness of the population
|
||||
Simulate,
|
||||
/// The node is currently selecting members of the population that scored well and reducing the total population size
|
||||
Score,
|
||||
/// The node is currently mutating members of it's population and breeding new members
|
||||
Mutate,
|
||||
/// The node has finished processing for a given number of iterations
|
||||
Finish,
|
||||
}
|
|
@ -1,5 +1,7 @@
|
|||
//! Simulates a genetic algorithm on a population in order to improve the fit score and performance. The simulations
|
||||
//! are performed in a tournament bracket configuration so that populations can compete against each other.
|
||||
|
||||
pub mod genetic_node;
|
||||
pub mod genetic_state;
|
||||
|
||||
use super::file_linked::FileLinked;
|
||||
use super::tree;
|
||||
|
@ -10,15 +12,91 @@ use std::fmt;
|
|||
use std::str::FromStr;
|
||||
use std::string::ToString;
|
||||
|
||||
/// As the bracket tree increases in height, `IterationScaling` can be used to configure the number of iterations that
|
||||
/// a node runs for.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// # use gemla::bracket::*;
|
||||
/// # use serde::{Deserialize, Serialize};
|
||||
/// # use std::fmt;
|
||||
/// # use std::str::FromStr;
|
||||
/// # use std::string::ToString;
|
||||
/// #
|
||||
/// # #[derive(Default, Deserialize, Serialize, Clone)]
|
||||
/// # struct TestState {
|
||||
/// # pub score: f64,
|
||||
/// # }
|
||||
/// #
|
||||
/// # impl FromStr for TestState {
|
||||
/// # type Err = String;
|
||||
/// #
|
||||
/// # fn from_str(s: &str) -> Result<TestState, Self::Err> {
|
||||
/// # toml::from_str(s).map_err(|_| format!("Unable to parse string {}", s))
|
||||
/// # }
|
||||
/// # }
|
||||
/// #
|
||||
/// # impl fmt::Display for TestState {
|
||||
/// # fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
/// # write!(f, "{}", self.score)
|
||||
/// # }
|
||||
/// # }
|
||||
/// #
|
||||
/// # impl TestState {
|
||||
/// # fn new(score: f64) -> TestState {
|
||||
/// # TestState { score: score }
|
||||
/// # }
|
||||
/// # }
|
||||
/// #
|
||||
/// # impl genetic_node::GeneticNode for TestState {
|
||||
/// # fn simulate(&mut self, iterations: u64) -> Result<(), String> {
|
||||
/// # self.score += iterations as f64;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn get_fit_score(&self) -> f64 {
|
||||
/// # self.score
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), String> {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn mutate(&mut self) -> Result<(), String> {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn initialize() -> Result<Box<Self>, String> {
|
||||
/// # Ok(Box::new(TestState { score: 0.0 }))
|
||||
/// # }
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn main() {
|
||||
/// let mut bracket = Bracket::<TestState>::initialize("./temp".to_string())
|
||||
/// .expect("Bracket failed to initialize");
|
||||
///
|
||||
/// // Constant iteration scaling ensures that every node is simulated 5 times.
|
||||
/// bracket
|
||||
/// .mutate(|b| drop(b.iteration_scaling(IterationScaling::Constant(5))))
|
||||
/// .expect("Failed to set iteration scaling");
|
||||
///
|
||||
/// # std::fs::remove_file("./temp").expect("Unable to remove file");
|
||||
/// # }
|
||||
/// ```
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Copy)]
|
||||
#[serde(tag = "enumType", content = "enumContent")]
|
||||
pub enum IterationScaling {
|
||||
Linear(u32),
|
||||
/// Scales the number of simulations linearly with the height of the bracket tree given by `f(x) = mx` where
|
||||
/// x is the height and m is the linear constant provided.
|
||||
Linear(u64),
|
||||
/// Each node in a bracket is simulated the same number of times.
|
||||
Constant(u64),
|
||||
}
|
||||
|
||||
impl Default for IterationScaling {
|
||||
fn default() -> Self {
|
||||
IterationScaling::Linear(1)
|
||||
IterationScaling::Constant(1)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -32,14 +110,25 @@ impl fmt::Display for IterationScaling {
|
|||
}
|
||||
}
|
||||
|
||||
/// Creates a tournament style bracket for simulating and evaluating nodes of type `T` implementing [`GeneticNode`].
|
||||
/// These nodes are built upwards as a balanced binary tree starting from the bottom. This results in `Bracket` building
|
||||
/// a separate tree of the same height then merging trees together. Evaluating populations between nodes and taking the strongest
|
||||
/// individuals.
|
||||
///
|
||||
/// [`GeneticNode`]: genetic_node::GeneticNode
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct Bracket<T> {
|
||||
pub struct Bracket<T>
|
||||
where
|
||||
T: genetic_node::GeneticNode,
|
||||
{
|
||||
tree: tree::Tree<T>,
|
||||
step: u64,
|
||||
iteration_scaling: IterationScaling,
|
||||
}
|
||||
|
||||
impl<T: fmt::Display + Serialize> fmt::Display for Bracket<T> {
|
||||
impl<T: fmt::Display + Serialize> fmt::Display for Bracket<T>
|
||||
where
|
||||
T: genetic_node::GeneticNode,
|
||||
{
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
|
@ -60,11 +149,81 @@ where
|
|||
+ Serialize
|
||||
+ Clone,
|
||||
{
|
||||
/// Initializes a bracket of type `T` storing the contents to `file_path`
|
||||
///
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// # use gemla::bracket::*;
|
||||
/// # use serde::{Deserialize, Serialize};
|
||||
/// # use std::fmt;
|
||||
/// # use std::str::FromStr;
|
||||
/// # use std::string::ToString;
|
||||
/// #
|
||||
/// #[derive(Default, Deserialize, Serialize, Clone)]
|
||||
/// struct TestState {
|
||||
/// pub score: f64,
|
||||
/// }
|
||||
///
|
||||
/// impl FromStr for TestState {
|
||||
/// type Err = String;
|
||||
///
|
||||
/// fn from_str(s: &str) -> Result<TestState, Self::Err> {
|
||||
/// toml::from_str(s).map_err(|_| format!("Unable to parse string {}", s))
|
||||
/// }
|
||||
/// }
|
||||
/// #
|
||||
/// # impl fmt::Display for TestState {
|
||||
/// # fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
/// # write!(f, "{}", self.score)
|
||||
/// # }
|
||||
/// # }
|
||||
/// #
|
||||
/// impl TestState {
|
||||
/// fn new(score: f64) -> TestState {
|
||||
/// TestState { score: score }
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// impl genetic_node::GeneticNode for TestState {
|
||||
/// # fn simulate(&mut self, iterations: u64) -> Result<(), String> {
|
||||
/// # self.score += iterations as f64;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn get_fit_score(&self) -> f64 {
|
||||
/// # self.score
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), String> {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn mutate(&mut self) -> Result<(), String> {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// #
|
||||
/// fn initialize() -> Result<Box<Self>, String> {
|
||||
/// Ok(Box::new(TestState { score: 0.0 }))
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// # fn main() {
|
||||
/// let mut bracket = Bracket::<TestState>::initialize("./temp".to_string())
|
||||
/// .expect("Bracket failed to initialize");
|
||||
///
|
||||
/// assert_eq!(
|
||||
/// format!("{}", bracket),
|
||||
/// format!("{{\"tree\":{},\"iteration_scaling\":{{\"enumType\":\"Constant\",\"enumContent\":1}}}}",
|
||||
/// btree!(TestState::new(0.0)))
|
||||
/// );
|
||||
///
|
||||
/// std::fs::remove_file("./temp").expect("Unable to remove file");
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn initialize(file_path: String) -> Result<FileLinked<Self>, String> {
|
||||
FileLinked::new(
|
||||
Bracket {
|
||||
tree: btree!(*T::initialize()?),
|
||||
step: 0,
|
||||
iteration_scaling: IterationScaling::default(),
|
||||
},
|
||||
file_path,
|
||||
|
@ -81,7 +240,8 @@ where
|
|||
let mut base_node = btree!(*T::initialize()?);
|
||||
|
||||
base_node.val.simulate(match self.iteration_scaling {
|
||||
IterationScaling::Linear(x) => (x as u64) * height,
|
||||
IterationScaling::Linear(x) => x * height,
|
||||
IterationScaling::Constant(x) => x,
|
||||
})?;
|
||||
|
||||
Ok(btree!(base_node.val))
|
||||
|
@ -95,7 +255,8 @@ where
|
|||
};
|
||||
|
||||
new_val.simulate(match self.iteration_scaling {
|
||||
IterationScaling::Linear(x) => (x as u64) * height,
|
||||
IterationScaling::Linear(x) => x * height,
|
||||
IterationScaling::Constant(x) => x,
|
||||
})?;
|
||||
|
||||
Ok(btree!(new_val, left, right))
|
||||
|
@ -103,10 +264,11 @@ where
|
|||
}
|
||||
|
||||
pub fn run_simulation_step(&mut self) -> Result<&mut Self, String> {
|
||||
let new_branch = self.create_new_branch(self.step + 1)?;
|
||||
let new_branch = self.create_new_branch(self.tree.height())?;
|
||||
|
||||
self.tree.val.simulate(match self.iteration_scaling {
|
||||
IterationScaling::Linear(x) => ((x as u64) * (self.step + 1)),
|
||||
IterationScaling::Linear(x) => (x * self.tree.height()),
|
||||
IterationScaling::Constant(x) => x,
|
||||
})?;
|
||||
|
||||
let new_val = if new_branch.val.get_fit_score() >= self.tree.val.get_fit_score() {
|
||||
|
@ -117,8 +279,6 @@ where
|
|||
|
||||
self.tree = btree!(new_val, new_branch, self.tree.clone());
|
||||
|
||||
self.step += 1;
|
||||
|
||||
Ok(self)
|
||||
}
|
||||
}
|
||||
|
@ -126,7 +286,7 @@ where
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
use std::str::FromStr;
|
||||
|
@ -187,7 +347,7 @@ mod tests {
|
|||
|
||||
assert_eq!(
|
||||
format!("{}", bracket),
|
||||
format!("{{\"tree\":{},\"step\":0,\"iteration_scaling\":{{\"enumType\":\"Linear\",\"enumContent\":1}}}}",
|
||||
format!("{{\"tree\":{},\"iteration_scaling\":{{\"enumType\":\"Constant\",\"enumContent\":1}}}}",
|
||||
btree!(TestState::new(0.0)))
|
||||
);
|
||||
|
||||
|
@ -210,7 +370,7 @@ mod tests {
|
|||
|
||||
assert_eq!(
|
||||
format!("{}", bracket),
|
||||
format!("{{\"tree\":{},\"step\":3,\"iteration_scaling\":{{\"enumType\":\"Linear\",\"enumContent\":2}}}}",
|
||||
format!("{{\"tree\":{},\"iteration_scaling\":{{\"enumType\":\"Linear\",\"enumContent\":2}}}}",
|
||||
btree!(
|
||||
TestState::new(12.0),
|
||||
btree!(
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::cmp::max;
|
||||
use std::fmt;
|
||||
use std::str::FromStr;
|
||||
|
||||
|
@ -102,6 +103,15 @@ impl<T> Tree<T> {
|
|||
Tree { val, left, right }
|
||||
}
|
||||
|
||||
pub fn height(&self) -> u64 {
|
||||
match (self.left.as_ref(), self.right.as_ref()) {
|
||||
(Some(l), Some(r)) => max(l.height(), r.height()) + 1,
|
||||
(Some(l), None) => l.height() + 1,
|
||||
(None, Some(r)) => r.height() + 1,
|
||||
_ => 1,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fmt_node(t: &Option<Box<Tree<T>>>) -> String
|
||||
where
|
||||
T: fmt::Display,
|
||||
|
@ -163,6 +173,13 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_height() {
|
||||
assert_eq!(1, btree!(1).height());
|
||||
|
||||
assert_eq!(3, btree!(1, btree!(2), btree!(2, btree!(3),)).height());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fmt_node() {
|
||||
let t = btree!(17, btree!(16), btree!(12));
|
||||
|
|
Loading…
Add table
Reference in a new issue