diff --git a/gemla/src/bracket/genetic_node.rs b/gemla/src/bracket/genetic_node.rs index 4d334e2..7235a43 100644 --- a/gemla/src/bracket/genetic_node.rs +++ b/gemla/src/bracket/genetic_node.rs @@ -7,24 +7,27 @@ use std::fmt; /// A trait used to interact with the internal state of nodes within the genetic bracket pub trait GeneticNode { + /// Initializes a new instance of a genetic state. + fn initialize() -> Result, String>; + /// Runs a simulation on the state object in order to guage it's fitness. /// - iterations: the number of iterations (learning cycles) that the current state should simulate /// /// This will be called for every node in a bracket before evaluating it's fitness against other nodes. - fn simulate(&mut self, iterations: u64); + fn simulate(&mut self, iterations: u64) -> Result<(), String>; /// Returns a fit score associated with the nodes performance. /// This will be used by a bracket in order to determine the most successful child. fn get_fit_score(&self) -> f64; - fn calculate_scores_and_trim(&mut self); + /// Used when scoring the nodes after simulating and should remove underperforming children. + fn calculate_scores_and_trim(&mut self) -> Result<(), String>; - fn mutate(&mut self); - - /// Initializes a new instance of a genetic state. - fn initialize() -> Self; + /// Mutates members in a population and/or crossbreeds them to produce new offspring. + fn mutate(&mut self) -> Result<(), String>; } +/// Used externally to wrap a node implementing the GeneticNode trait. Processes state transitions for the given node as well as signal recovery. #[derive(Serialize, Deserialize, Clone, Debug)] pub struct GeneticNodeWrapper where @@ -39,17 +42,19 @@ impl GeneticNodeWrapper where T: GeneticNode + fmt::Debug, { - fn new() -> Self { + /// Initializes a wrapper around a GeneticNode + fn new() -> Result { let mut node = GeneticNodeWrapper { data: None, state: GeneticState::Initialize, iteration: 0, }; - node.data = Some(T::initialize()); + let new_data = T::initialize()?; + node.data = Some(*new_data); node.state = GeneticState::Simulate; - node + Ok(node) } fn process_node(&mut self, iterations: u32) -> Result<(), String> { @@ -59,15 +64,25 @@ where match (self.state, self.data.as_ref()) { (GeneticState::Initialize, _) => { self.iteration = 0; - self.data = Some(T::initialize()); + let new_data = + T::initialize().map_err(|e| format!("Error initializing node: {}", e))?; + self.data = Some(*new_data); self.state = GeneticState::Simulate; } (GeneticState::Simulate, Some(_)) => { - self.data.as_mut().unwrap().simulate(5); + self.data + .as_mut() + .unwrap() + .simulate(5) + .map_err(|e| format!("Error simulating node: {}", e))?; self.state = GeneticState::Score; } (GeneticState::Score, Some(_)) => { - self.data.as_mut().unwrap().calculate_scores_and_trim(); + self.data + .as_mut() + .unwrap() + .calculate_scores_and_trim() + .map_err(|e| format!("Error scoring and trimming node: {}", e))?; self.state = if self.iteration == iterations { GeneticState::Finish @@ -76,7 +91,11 @@ where } } (GeneticState::Mutate, Some(_)) => { - self.data.as_mut().unwrap().mutate(); + self.data + .as_mut() + .unwrap() + .mutate() + .map_err(|e| format!("Error mutating node: {}", e))?; self.state = GeneticState::Simulate; } (GeneticState::Finish, Some(_)) => { diff --git a/gemla/src/bracket/mod.rs b/gemla/src/bracket/mod.rs index b27b953..e487877 100644 --- a/gemla/src/bracket/mod.rs +++ b/gemla/src/bracket/mod.rs @@ -63,7 +63,7 @@ where pub fn initialize(file_path: String) -> Result, String> { FileLinked::new( Bracket { - tree: btree!(T::initialize()), + tree: btree!(*T::initialize()?), step: 0, iteration_scaling: IterationScaling::default(), }, @@ -76,18 +76,18 @@ where self } - pub fn create_new_branch(&self, height: u64) -> tree::Tree { + pub fn create_new_branch(&self, height: u64) -> Result, String> { if height == 1 { - let mut base_node = btree!(T::initialize()); + let mut base_node = btree!(*T::initialize()?); base_node.val.simulate(match self.iteration_scaling { IterationScaling::Linear(x) => (x as u64) * height, - }); + })?; - btree!(base_node.val) + Ok(btree!(base_node.val)) } else { - let left = self.create_new_branch(height - 1); - let right = self.create_new_branch(height - 1); + let left = self.create_new_branch(height - 1)?; + let right = self.create_new_branch(height - 1)?; let mut new_val = if left.val.get_fit_score() >= right.val.get_fit_score() { left.val.clone() } else { @@ -96,18 +96,18 @@ where new_val.simulate(match self.iteration_scaling { IterationScaling::Linear(x) => (x as u64) * height, - }); + })?; - btree!(new_val, left, right) + Ok(btree!(new_val, left, right)) } } - pub fn run_simulation_step(&mut self) -> &mut Self { - let new_branch = self.create_new_branch(self.step + 1); + pub fn run_simulation_step(&mut self) -> Result<&mut Self, String> { + let new_branch = self.create_new_branch(self.step + 1)?; self.tree.val.simulate(match self.iteration_scaling { IterationScaling::Linear(x) => ((x as u64) * (self.step + 1)), - }); + })?; let new_val = if new_branch.val.get_fit_score() >= self.tree.val.get_fit_score() { new_branch.val.clone() @@ -119,6 +119,6 @@ where self.step += 1; - self + Ok(self) } } diff --git a/gemla/src/tests/bracket.rs b/gemla/src/tests/bracket.rs index 92b9d71..c9deb81 100644 --- a/gemla/src/tests/bracket.rs +++ b/gemla/src/tests/bracket.rs @@ -31,20 +31,25 @@ impl TestState { } impl bracket::genetic_node::GeneticNode for TestState { - fn simulate(&mut self, iterations: u64) { + fn simulate(&mut self, iterations: u64) -> Result<(), String> { self.score += iterations as f64; + Ok(()) } fn get_fit_score(&self) -> f64 { self.score } - fn calculate_scores_and_trim(&mut self) {} + fn calculate_scores_and_trim(&mut self) -> Result<(), String> { + Ok(()) + } - fn mutate(&mut self) {} + fn mutate(&mut self) -> Result<(), String> { + Ok(()) + } - fn initialize() -> Self { - TestState { score: 0.0 } + fn initialize() -> Result, String> { + Ok(Box::new(TestState { score: 0.0 })) } }