diff --git a/gemla/src/bracket/genetic_node.rs b/gemla/src/bracket/genetic_node.rs index 64d738a..c85d38a 100644 --- a/gemla/src/bracket/genetic_node.rs +++ b/gemla/src/bracket/genetic_node.rs @@ -1,13 +1,17 @@ -//! A trait used to interact with the internal state of nodes within the genetic bracket +//! A trait used to interact with the internal state of nodes within the [`Bracket`] +//! +//! [`Bracket`]: crate::bracket::Bracket use super::genetic_state::GeneticState; use serde::{Deserialize, Serialize}; use std::fmt; -/// A trait used to interact with the internal state of nodes within the genetic bracket +/// A trait used to interact with the internal state of nodes within the [`Bracket`] +/// +/// [`Bracket`]: crate::bracket::Bracket pub trait GeneticNode { - /// Initializes a new instance of a genetic state. + /// Initializes a new instance of a [`GeneticState`]. /// /// # Examples /// @@ -16,7 +20,7 @@ pub trait GeneticNode { /// # /// struct Node { /// pub fit_score: f64, - /// } + /// } /// /// impl GeneticNode for Node { /// fn initialize() -> Result, String> { @@ -50,10 +54,68 @@ pub trait GeneticNode { /// ``` fn initialize() -> Result, String>; - /// Runs a simulation on the state object in order to guage it's fitness. - /// - iterations: the number of iterations (learning cycles) that the current state should simulate - /// + /// Runs a simulation on the state object for the given number of `iterations` in order to guage it's fitness. /// This will be called for every node in a bracket before evaluating it's fitness against other nodes. + /// + /// #Examples + /// + /// ``` + /// # use gemla::bracket::genetic_node::GeneticNode; + /// # + /// struct Model { + /// pub fit_score: f64, + /// //... + /// } + /// + /// struct Node { + /// pub fit_score: f64, + /// pub model: Model, + /// //... + /// } + /// + /// impl Model { + /// fn fit(&mut self, epochs: u64) -> Result<(), String> { + /// //... + /// # self.fit_score += epochs as f64; + /// # Ok(()) + /// } + /// } + /// + /// impl GeneticNode for Node { + /// # fn initialize() -> Result, String> { + /// # Ok(Box::new(Node {fit_score: 0.0, model: Model {fit_score: 0.0}})) + /// # } + /// # + /// //... + /// + /// fn simulate(&mut self, iterations: u64) -> Result<(), String> { + /// self.model.fit(iterations)?; + /// self.fit_score = self.model.fit_score; + /// Ok(()) + /// } + /// + /// //... + /// # + /// # fn get_fit_score(&self) -> f64 { + /// # self.fit_score + /// # } + /// # + /// # fn calculate_scores_and_trim(&mut self) -> Result<(), String> { + /// # Ok(()) + /// # } + /// # + /// # fn mutate(&mut self) -> Result<(), String> { + /// # Ok(()) + /// # } + /// } + /// + /// # fn main() -> Result<(), String> { + /// let mut node = Node::initialize()?; + /// (*node).simulate(5)?; + /// # assert_eq!(node.get_fit_score(), 5.0); + /// # Ok(()) + /// # } + /// ``` fn simulate(&mut self, iterations: u64) -> Result<(), String>; /// Returns a fit score associated with the nodes performance. @@ -67,23 +129,67 @@ pub trait GeneticNode { fn mutate(&mut self) -> Result<(), String>; } -/// Used externally to wrap a node implementing the GeneticNode trait. Processes state transitions for the given node as well as signal recovery. +/// Used externally to wrap a node implementing the [`GeneticNode`] trait. Processes state transitions for the given node as +/// well as signal recovery. Transition states are given by [`GeneticState`] #[derive(Serialize, Deserialize, Clone, Debug)] pub struct GeneticNodeWrapper where T: GeneticNode, { - data: Option, + pub data: Option, state: GeneticState, - iteration: u32, + pub iteration: u32, } impl GeneticNodeWrapper where T: GeneticNode + fmt::Debug, { - /// Initializes a wrapper around a GeneticNode - fn new() -> Result { + /// Initializes a wrapper around a GeneticNode. If the initialization is successful the internal state will be changed to + /// `GeneticState::Simulate` otherwise it will remain as `GeneticState::Initialize` and will attempt to be created in + /// [`process_node`](#method.process_node). + /// + /// # Examples + /// ``` + /// # use gemla::bracket::genetic_node::GeneticNode; + /// # use gemla::bracket::genetic_node::GeneticNodeWrapper; + /// # #[derive(Debug)] + /// struct Node { + /// # pub fit_score: f64, + /// //... + /// } + /// + /// impl GeneticNode for Node { + /// //... + /// # fn initialize() -> Result, String> { + /// # Ok(Box::new(Node {fit_score: 0.0})) + /// # } + /// # + /// # + /// # fn simulate(&mut self, iterations: u64) -> Result<(), String> { + /// # Ok(()) + /// # } + /// # + /// # fn get_fit_score(&self) -> f64 { + /// # self.fit_score + /// # } + /// # + /// # fn calculate_scores_and_trim(&mut self) -> Result<(), String> { + /// # Ok(()) + /// # } + /// # + /// # fn mutate(&mut self) -> Result<(), String> { + /// # Ok(()) + /// # } + /// } + /// + /// # fn main() -> Result<(), String> { + /// let mut wrapped_node = GeneticNodeWrapper::::new()?; + /// assert_eq!(wrapped_node.data.unwrap().get_fit_score(), 0.0); + /// # Ok(()) + /// # } + /// ``` + pub fn new() -> Result { let mut node = GeneticNodeWrapper { data: None, state: GeneticState::Initialize, @@ -97,9 +203,24 @@ where Ok(node) } - fn process_node(&mut self, iterations: u32) -> Result<(), String> { - let mut result = Ok(()); - + /// Performs state transitions on the [`GeneticNode`] wrapped by the [`GeneticNodeWrapper`]. + /// Will loop through the node training and scoring process for the given number of `iterations`. + /// + /// ## Transitions + /// - `GeneticState::Initialize`: will attempt to call [`initialize`] on the node. When done successfully will change + /// the state to `GeneticState::Simulate` + /// - `GeneticState::Simulate`: Will call [`simulate`] with a number of iterations (not for `iterations`). Will change the state to `GeneticState::Score` + /// - `GeneticState::Score`: Will call [`calculate_scores_and_trim`] and when the number of `iterations` have been reached will change + /// state to `GeneticState::Finish`, otherwise it will change the state to `GeneticState::Mutate. + /// - `GeneticState::Mutate`: Will call [`mutate`] and will change the state to `GeneticState::Simulate.` + /// - `GeneticState::Finish`: Will finish processing the node and return. + /// + /// [`initialize`]: crate::bracket::genetic_node::GeneticNode#tymethod.initialize + /// [`simulate`]: crate::bracket::genetic_node::GeneticNode#tymethod.simulate + /// [`calculate_scores_and_trim`]: crate::bracket::genetic_node::GeneticNode#tymethod.calculate_scores_and_trim + /// [`mutate`]: crate::bracket::genetic_node::GeneticNode#tymethod.mutate + pub fn process_node(&mut self, iterations: u32) -> Result<(), String> { + // Looping through each state transition until the number of iterations have been reached. loop { match (self.state, self.data.as_ref()) { (GeneticState::Initialize, _) => { @@ -141,10 +262,10 @@ where (GeneticState::Finish, Some(_)) => { break; } - _ => result = Err(format!("Error processing node {:?}", self.data)), + _ => return Err(format!("Error processing node {:?}", self.data)), } } - result + Ok(()) } }