diff --git a/gemla/src/bracket/genetic_node.rs b/gemla/src/bracket/genetic_node.rs new file mode 100644 index 0000000..4d334e2 --- /dev/null +++ b/gemla/src/bracket/genetic_node.rs @@ -0,0 +1,91 @@ +//! A trait used to interact with the internal state of nodes within the genetic bracket + +use super::genetic_state::GeneticState; + +use serde::{Deserialize, Serialize}; +use std::fmt; + +/// A trait used to interact with the internal state of nodes within the genetic bracket +pub trait GeneticNode { + /// Runs a simulation on the state object in order to guage it's fitness. + /// - iterations: the number of iterations (learning cycles) that the current state should simulate + /// + /// This will be called for every node in a bracket before evaluating it's fitness against other nodes. + fn simulate(&mut self, iterations: u64); + + /// Returns a fit score associated with the nodes performance. + /// This will be used by a bracket in order to determine the most successful child. + fn get_fit_score(&self) -> f64; + + fn calculate_scores_and_trim(&mut self); + + fn mutate(&mut self); + + /// Initializes a new instance of a genetic state. + fn initialize() -> Self; +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct GeneticNodeWrapper +where + T: GeneticNode, +{ + data: Option, + state: GeneticState, + iteration: u32, +} + +impl GeneticNodeWrapper +where + T: GeneticNode + fmt::Debug, +{ + fn new() -> Self { + let mut node = GeneticNodeWrapper { + data: None, + state: GeneticState::Initialize, + iteration: 0, + }; + + node.data = Some(T::initialize()); + node.state = GeneticState::Simulate; + + node + } + + fn process_node(&mut self, iterations: u32) -> Result<(), String> { + let mut result = Ok(()); + + loop { + match (self.state, self.data.as_ref()) { + (GeneticState::Initialize, _) => { + self.iteration = 0; + self.data = Some(T::initialize()); + self.state = GeneticState::Simulate; + } + (GeneticState::Simulate, Some(_)) => { + self.data.as_mut().unwrap().simulate(5); + self.state = GeneticState::Score; + } + (GeneticState::Score, Some(_)) => { + self.data.as_mut().unwrap().calculate_scores_and_trim(); + + self.state = if self.iteration == iterations { + GeneticState::Finish + } else { + GeneticState::Mutate + } + } + (GeneticState::Mutate, Some(_)) => { + self.data.as_mut().unwrap().mutate(); + self.state = GeneticState::Simulate; + } + (GeneticState::Finish, Some(_)) => { + break; + } + _ => result = Err(format!("Error processing node {:?}", self.data)), + } + } + + result + } +} diff --git a/gemla/src/bracket/genetic_state.rs b/gemla/src/bracket/genetic_state.rs index c86dd4e..1c5e295 100644 --- a/gemla/src/bracket/genetic_state.rs +++ b/gemla/src/bracket/genetic_state.rs @@ -1,17 +1,11 @@ -//! A trait used to interact with the internal state of nodes within the genetic bracket +use serde::{Deserialize, Serialize}; -/// A trait used to interact with the internal state of nodes within the genetic bracket -pub trait GeneticState { - /// Runs a simulation on the state object in order to guage it's fitness. - /// - iterations: the number of iterations (learning cycles) that the current state should simulate - /// - /// This will be called for every node in a bracket before evaluating it's fitness against other nodes. - fn run_simulation(&mut self, iterations: u64); - - /// Returns a fit score associated with the nodes performance. - /// This will be used by a bracket in order to determine the most successful child. - fn get_fit_score(&self) -> f64; - - /// Initializes a new instance of a genetic state. - fn initialize() -> Self; +#[derive(Clone, Debug, Serialize, Deserialize, Copy)] +#[serde(tag = "enumType", content = "enumContent")] +pub enum GeneticState { + Initialize, + Simulate, + Score, + Mutate, + Finish, } diff --git a/gemla/src/bracket/mod.rs b/gemla/src/bracket/mod.rs index 023c2ed..b27b953 100644 --- a/gemla/src/bracket/mod.rs +++ b/gemla/src/bracket/mod.rs @@ -1,3 +1,4 @@ +pub mod genetic_node; pub mod genetic_state; use super::file_linked::FileLinked; @@ -9,7 +10,7 @@ use std::fmt; use std::str::FromStr; use std::string::ToString; -#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize, Copy)] #[serde(tag = "enumType", content = "enumContent")] pub enum IterationScaling { Linear(u32), @@ -50,7 +51,7 @@ impl fmt::Display for Bracket { impl Bracket where - T: genetic_state::GeneticState + T: genetic_node::GeneticNode + ToString + FromStr + Default @@ -79,7 +80,7 @@ where if height == 1 { let mut base_node = btree!(T::initialize()); - base_node.val.run_simulation(match self.iteration_scaling { + base_node.val.simulate(match self.iteration_scaling { IterationScaling::Linear(x) => (x as u64) * height, }); @@ -93,7 +94,7 @@ where right.val.clone() }; - new_val.run_simulation(match self.iteration_scaling { + new_val.simulate(match self.iteration_scaling { IterationScaling::Linear(x) => (x as u64) * height, }); @@ -104,7 +105,7 @@ where pub fn run_simulation_step(&mut self) -> &mut Self { let new_branch = self.create_new_branch(self.step + 1); - self.tree.val.run_simulation(match self.iteration_scaling { + self.tree.val.simulate(match self.iteration_scaling { IterationScaling::Linear(x) => ((x as u64) * (self.step + 1)), }); diff --git a/gemla/src/tests/bracket.rs b/gemla/src/tests/bracket.rs index cfae5d5..92b9d71 100644 --- a/gemla/src/tests/bracket.rs +++ b/gemla/src/tests/bracket.rs @@ -30,8 +30,8 @@ impl TestState { } } -impl bracket::genetic_state::GeneticState for TestState { - fn run_simulation(&mut self, iterations: u64) { +impl bracket::genetic_node::GeneticNode for TestState { + fn simulate(&mut self, iterations: u64) { self.score += iterations as f64; } @@ -39,6 +39,10 @@ impl bracket::genetic_state::GeneticState for TestState { self.score } + fn calculate_scores_and_trim(&mut self) {} + + fn mutate(&mut self) {} + fn initialize() -> Self { TestState { score: 0.0 } } @@ -66,7 +70,6 @@ fn test_run() { bracket .mutate(|b| drop(b.iteration_scaling(bracket::IterationScaling::Linear(2)))) .expect("Failed to set iteration scaling"); - for _ in 0..3 { bracket .mutate(|b| drop(b.run_simulation_step())) diff --git a/gemla/src/tree/mod.rs b/gemla/src/tree/mod.rs index f5486f7..ed0ffb9 100644 --- a/gemla/src/tree/mod.rs +++ b/gemla/src/tree/mod.rs @@ -76,11 +76,7 @@ pub struct Tree { #[macro_export] macro_rules! btree { ($val:expr, $l:expr, $r:expr) => { - $crate::tree::Tree::new( - $val, - Some(Box::new($l)), - Some(Box::new($r)), - ) + $crate::tree::Tree::new($val, Some(Box::new($l)), Some(Box::new($r))) }; ($val:expr, , $r:expr) => { $crate::tree::Tree::new($val, None, Some(Box::new($r)))