diff --git a/gemla/src/bracket/genetic_node.rs b/gemla/src/bracket/genetic_node.rs index dcc10d9..3c9e4b4 100644 --- a/gemla/src/bracket/genetic_node.rs +++ b/gemla/src/bracket/genetic_node.rs @@ -11,7 +11,7 @@ use std::fmt; /// An enum used to control the state of a [`GeneticNode`] /// /// [`GeneticNode`]: crate::bracket::genetic_node -#[derive(Clone, Debug, Serialize, Deserialize, Copy)] +#[derive(Clone, Debug, Serialize, Deserialize, Copy, PartialEq)] #[serde(tag = "enumType", content = "enumContent")] pub enum GeneticState { /// The node and it's data have not finished initializing @@ -396,14 +396,14 @@ pub trait GeneticNode { /// Used externally to wrap a node implementing the [`GeneticNode`] trait. Processes state transitions for the given node as /// well as signal recovery. Transition states are given by [`GeneticState`] -#[derive(Serialize, Deserialize, Clone, Debug)] +#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] pub struct GeneticNodeWrapper where T: GeneticNode, { pub data: Option, state: GeneticState, - pub iteration: u32, + pub iteration: u64, } impl GeneticNodeWrapper @@ -485,7 +485,7 @@ where /// [`simulate`]: crate::bracket::genetic_node::GeneticNode#tymethod.simulate /// [`calculate_scores_and_trim`]: crate::bracket::genetic_node::GeneticNode#tymethod.calculate_scores_and_trim /// [`mutate`]: crate::bracket::genetic_node::GeneticNode#tymethod.mutate - pub fn process_node(&mut self, iterations: u32) -> Result<(), Error> { + pub fn process_node(&mut self, iterations: u64) -> Result<(), Error> { // Looping through each state transition until the number of iterations have been reached. loop { match (self.state, &self.data) { @@ -523,6 +523,8 @@ where .unwrap() .mutate() .with_context(|| format!("Error mutating node: {:?}", self))?; + + self.iteration += 1; self.state = GeneticState::Simulate; } (GeneticState::Finish, Some(_)) => { diff --git a/gemla/src/bracket/mod.rs b/gemla/src/bracket/mod.rs index b896b80..29804ca 100644 --- a/gemla/src/bracket/mod.rs +++ b/gemla/src/bracket/mod.rs @@ -5,10 +5,12 @@ pub mod genetic_node; use crate::error::Error; use crate::tree; +use genetic_node::GeneticNodeWrapper; use file_linked::FileLinked; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; +use std::fmt::Debug; use std::path; /// As the bracket tree increases in height, `IterationScaling` can be used to configure the number of iterations that @@ -25,7 +27,7 @@ use std::path; /// # use std::string::ToString; /// # use std::path; /// # -/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq)] +/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq, Debug)] /// # struct TestState { /// # pub score: f64, /// # } @@ -98,13 +100,19 @@ pub struct Bracket where T: genetic_node::GeneticNode + Serialize, { - pub tree: tree::Tree, + pub tree: tree::Tree>>, iteration_scaling: IterationScaling, } impl Bracket where - T: genetic_node::GeneticNode + Default + DeserializeOwned + Serialize + Clone + PartialEq, + T: genetic_node::GeneticNode + + Default + + DeserializeOwned + + Serialize + + Clone + + PartialEq + + Debug, { /// Initializes a bracket of type `T` storing the contents to `file_path` /// @@ -180,7 +188,7 @@ where pub fn initialize(file_path: path::PathBuf) -> Result, Error> { Ok(FileLinked::new( Bracket { - tree: btree!(*T::initialize()?), + tree: btree!(Some(GeneticNodeWrapper::new()?)), iteration_scaling: IterationScaling::default(), }, file_path, @@ -199,7 +207,7 @@ where /// # use std::string::ToString; /// # use std::path; /// # - /// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq)] + /// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq, Debug)] /// # struct TestState { /// # pub score: f64, /// # } @@ -258,31 +266,36 @@ where // Creates a balanced tree with the given `height` that will be used as a branch of the primary tree. // This additionally simulates and evaluates nodes in the branch as it is built. - fn create_new_branch(&self, height: u64) -> Result, Error> { + fn create_new_branch( + &self, + height: u64, + ) -> Result>>, Error> { if height == 1 { - let mut base_node = btree!(*T::initialize()?); + let mut base_node = GeneticNodeWrapper::new()?; - base_node.val.simulate(match self.iteration_scaling { + base_node.process_node(match self.iteration_scaling { IterationScaling::Linear(x) => x * height, IterationScaling::Constant(x) => x, })?; - Ok(btree!(base_node.val)) + Ok(btree!(Some(base_node))) } else { let left = self.create_new_branch(height - 1)?; let right = self.create_new_branch(height - 1)?; - let mut new_val = if left.val.get_fit_score() >= right.val.get_fit_score() { - left.val.clone() + let mut new_val = if left.val.clone().unwrap().data.unwrap().get_fit_score() + >= right.val.clone().unwrap().data.unwrap().get_fit_score() + { + left.val.clone().unwrap() } else { - right.val.clone() + right.val.clone().unwrap() }; - new_val.simulate(match self.iteration_scaling { + new_val.process_node(match self.iteration_scaling { IterationScaling::Linear(x) => x * height, IterationScaling::Constant(x) => x, })?; - Ok(btree!(new_val, left, right)) + Ok(btree!(Some(new_val), left, right)) } } @@ -302,7 +315,7 @@ where /// # use std::string::ToString; /// # use std::path; /// # - /// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq)] + /// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq, Debug)] /// # struct TestState { /// # pub score: f64, /// # } @@ -361,12 +374,24 @@ where pub fn run_simulation_step(&mut self) -> Result<&mut Self, Error> { let new_branch = self.create_new_branch(self.tree.height())?; - self.tree.val.simulate(match self.iteration_scaling { - IterationScaling::Linear(x) => (x * self.tree.height()), - IterationScaling::Constant(x) => x, - })?; + self.tree + .val + .clone() + .unwrap() + .process_node(match self.iteration_scaling { + IterationScaling::Linear(x) => (x * self.tree.height()), + IterationScaling::Constant(x) => x, + })?; - let new_val = if new_branch.val.get_fit_score() >= self.tree.val.get_fit_score() { + let new_val = if new_branch + .val + .clone() + .unwrap() + .data + .unwrap() + .get_fit_score() + >= self.tree.val.clone().unwrap().data.unwrap().get_fit_score() + { new_branch.val.clone() } else { self.tree.val.clone() @@ -432,7 +457,7 @@ mod tests { file_linked::FileLinked::new( Bracket { tree: Tree { - val: TestState { score: 0.0 }, + val: Some(GeneticNodeWrapper::new().unwrap()), left: None, right: None }, @@ -460,44 +485,18 @@ mod tests { .expect("Failed to run step"); } + assert_eq!(bracket.readonly().tree.height(), 4); assert_eq!( - bracket, - file_linked::FileLinked::new( - Bracket { - iteration_scaling: IterationScaling::Linear(2), - tree: btree!( - TestState { score: 12.0 }, - btree!( - TestState { score: 12.0 }, - btree!( - TestState { score: 6.0 }, - btree!(TestState { score: 2.0 }), - btree!(TestState { score: 2.0 }) - ), - btree!( - TestState { score: 6.0 }, - btree!(TestState { score: 2.0 }), - btree!(TestState { score: 2.0 }) - ) - ), - btree!( - TestState { score: 12.0 }, - btree!( - TestState { score: 6.0 }, - btree!(TestState { score: 2.0 }), - btree!(TestState { score: 2.0 }) - ), - btree!( - TestState { score: 6.0 }, - btree!(TestState { score: 2.0 }), - btree!(TestState { score: 2.0 }) - ) - ) - ) - }, - path::PathBuf::from("./temp2") - ) - .unwrap() + bracket + .readonly() + .tree + .val + .clone() + .unwrap() + .data + .unwrap() + .score, + 15.0 ); std::fs::remove_file("./temp2").expect("Unable to remove file");