diff --git a/gemla/src/bin/bin.rs b/gemla/src/bin/bin.rs index f4997ec..0b3f9ee 100644 --- a/gemla/src/bin/bin.rs +++ b/gemla/src/bin/bin.rs @@ -5,9 +5,10 @@ extern crate gemla; mod test_state; use clap::App; -use gemla::bracket::Gemla; +use gemla::core::{Gemla, GemlaConfig}; use std::path::PathBuf; use test_state::TestState; +// use std::io::Write; /// Runs a simluation of a genetic algorithm against a dataset. /// @@ -21,9 +22,18 @@ fn main() -> anyhow::Result<()> { // Checking that the first argument is a valid directory let file_path = matches.value_of(gemla::constants::args::FILE).unwrap(); - let mut gemla = Gemla::::new(&PathBuf::from(file_path), true)?; + let mut gemla = Gemla::::new( + &PathBuf::from(file_path), + GemlaConfig { + generations_per_node: 10, + overwrite: true, + }, + )?; - gemla.simulate(17)?; + gemla.simulate(10)?; + + // let mut f = std::fs::File::create("./test")?; + // write!(f, "{}", serde_json::to_string(&gemla.data.readonly().0)?)?; Ok(()) } diff --git a/gemla/src/bin/test_state/mod.rs b/gemla/src/bin/test_state/mod.rs index db6c2ba..9d2c599 100644 --- a/gemla/src/bin/test_state/mod.rs +++ b/gemla/src/bin/test_state/mod.rs @@ -1,4 +1,4 @@ -use gemla::bracket::genetic_node::GeneticNode; +use gemla::core::genetic_node::GeneticNode; use gemla::error; use rand::prelude::*; use rand::thread_rng; @@ -18,22 +18,20 @@ impl GeneticNode for TestState { let mut population: Vec = vec![]; for _ in 0..POPULATION_SIZE { - population.push(thread_rng().gen_range(0..10000)) + population.push(thread_rng().gen_range(0..100)) } Ok(Box::new(TestState { population })) } - fn simulate(&mut self, iterations: u64) -> Result<(), error::Error> { + fn simulate(&mut self) -> Result<(), error::Error> { let mut rng = thread_rng(); - for _ in 0..iterations { - self.population = self - .population - .iter() - .map(|p| p + rng.gen_range(-10..10)) - .collect() - } + self.population = self + .population + .iter() + .map(|p| p.saturating_add(rng.gen_range(-1..2))) + .collect(); Ok(()) } @@ -67,7 +65,8 @@ impl GeneticNode for TestState { let mut new_individual = self.population.clone()[new_individual_index]; let cross_breed = self.population.clone()[cross_breed_index]; - new_individual += cross_breed + rng.gen_range(-10..10); + new_individual = (new_individual.saturating_add(cross_breed) / 2) + .saturating_add(rng.gen_range(-1..2)); self.population.push(new_individual); } @@ -95,7 +94,7 @@ impl GeneticNode for TestState { #[cfg(test)] mod tests { use super::*; - use gemla::bracket::genetic_node::GeneticNode; + use gemla::core::genetic_node::GeneticNode; #[test] fn test_initialize() { @@ -112,20 +111,18 @@ mod tests { let original_population = state.population.clone(); - state.simulate(0).unwrap(); - assert_eq!(original_population, state.population); - - state.simulate(1).unwrap(); + state.simulate().unwrap(); assert!(original_population .iter() .zip(state.population.iter()) - .all(|(&a, &b)| b >= a - 10 && b <= a + 10)); + .all(|(&a, &b)| b >= a - 1 && b <= a + 2)); - state.simulate(2).unwrap(); + state.simulate().unwrap(); + state.simulate().unwrap(); assert!(original_population .iter() .zip(state.population.iter()) - .all(|(&a, &b)| b >= a - 30 && b <= a + 30)) + .all(|(&a, &b)| b >= a - 3 && b <= a + 6)) } #[test] diff --git a/gemla/src/bracket/genetic_node.rs b/gemla/src/bracket/genetic_node.rs deleted file mode 100644 index 08af901..0000000 --- a/gemla/src/bracket/genetic_node.rs +++ /dev/null @@ -1,154 +0,0 @@ -//! A trait used to interact with the internal state of nodes within the [`Bracket`] -//! -//! [`Bracket`]: crate::bracket::Bracket - -use crate::error::Error; - -use anyhow::Context; -use serde::{Deserialize, Serialize}; -use std::fmt::Debug; - -/// An enum used to control the state of a [`GeneticNode`] -/// -/// [`GeneticNode`]: crate::bracket::genetic_node -#[derive(Debug, Serialize, Deserialize)] -#[serde(tag = "enumType", content = "enumContent")] -pub enum GeneticState { - /// The node and it's data have not finished initializing - Initialize, - /// The node is currently simulating a round against target data to determine the fitness of the population - Simulate, - /// The node is currently mutating members of it's population and breeding new members - Mutate, - /// The node has finished processing for a given number of iterations - Finish, -} - -/// A trait used to interact with the internal state of nodes within the [`Bracket`] -/// -/// [`Bracket`]: crate::bracket::Bracket -pub trait GeneticNode { - /// Initializes a new instance of a [`GeneticState`]. - /// - /// # Examples - /// TODO - fn initialize() -> Result, Error>; - - /// Runs a simulation on the state object for the given number of `iterations` in order to guage it's fitness. - /// This will be called for every node in a bracket before evaluating it's fitness against other nodes. - /// - /// # Examples - /// TODO - fn simulate(&mut self, iterations: u64) -> Result<(), Error>; - - /// Mutates members in a population and/or crossbreeds them to produce new offspring. - /// - /// # Examples - /// TODO - fn mutate(&mut self) -> Result<(), Error>; - - fn merge(left: &Self, right: &Self) -> Result, Error>; -} - -/// Used externally to wrap a node implementing the [`GeneticNode`] trait. Processes state transitions for the given node as -/// well as signal recovery. Transition states are given by [`GeneticState`] -#[derive(Debug, Serialize, Deserialize)] -pub struct GeneticNodeWrapper { - pub data: Option, - state: GeneticState, - pub iteration: u64, -} - -impl GeneticNodeWrapper -where - T: GeneticNode + Debug, -{ - /// Initializes a wrapper around a GeneticNode. If the initialization is successful the internal state will be changed to - /// `GeneticState::Simulate` otherwise it will remain as `GeneticState::Initialize` and will attempt to be created in - /// [`process_node`](#method.process_node). - /// - /// # Examples - /// TODO - pub fn new() -> Result { - let mut node = GeneticNodeWrapper { - data: None, - state: GeneticState::Initialize, - iteration: 0, - }; - - let new_data = T::initialize()?; - node.data = Some(*new_data); - node.state = GeneticState::Simulate; - - Ok(node) - } - - pub fn from(data: T) -> Result { - let mut node = GeneticNodeWrapper { - data: Some(data), - state: GeneticState::Initialize, - iteration: 0, - }; - - node.state = GeneticState::Simulate; - - Ok(node) - } - - /// Performs state transitions on the [`GeneticNode`] wrapped by the [`GeneticNodeWrapper`]. - /// Will loop through the node training and scoring process for the given number of `iterations`. - /// - /// ## Transitions - /// - `GeneticState::Initialize`: will attempt to call [`initialize`] on the node. When done successfully will change - /// the state to `GeneticState::Simulate` - /// - `GeneticState::Simulate`: Will call [`simulate`] with a number of iterations (not for `iterations`). Will change the state to `GeneticState::Score` - /// - `GeneticState::Mutate`: Will call [`mutate`] and will change the state to `GeneticState::Simulate.` - /// - `GeneticState::Finish`: Will finish processing the node and return. - /// - /// [`initialize`]: crate::bracket::genetic_node::GeneticNode#tymethod.initialize - /// [`simulate`]: crate::bracket::genetic_node::GeneticNode#tymethod.simulate - /// [`mutate`]: crate::bracket::genetic_node::GeneticNode#tymethod.mutate - pub fn process_node(&mut self, iterations: u64) -> Result<(), Error> { - // Looping through each state transition until the number of iterations have been reached. - loop { - match (&self.state, &self.data) { - (GeneticState::Initialize, _) => { - self.iteration = 0; - let new_data = T::initialize() - .with_context(|| format!("Error initializing node {:?}", self))?; - self.data = Some(*new_data); - self.state = GeneticState::Simulate; - } - (GeneticState::Simulate, Some(_)) => { - self.data - .as_mut() - .unwrap() - .simulate(5) - .with_context(|| format!("Error simulating node: {:?}", self))?; - - self.state = if self.iteration == iterations { - GeneticState::Finish - } else { - GeneticState::Mutate - }; - } - (GeneticState::Mutate, Some(_)) => { - self.data - .as_mut() - .unwrap() - .mutate() - .with_context(|| format!("Error mutating node: {:?}", self))?; - - self.iteration += 1; - self.state = GeneticState::Simulate; - } - (GeneticState::Finish, Some(_)) => { - break; - } - _ => panic!("Error processing node {:?}", self.data), - } - } - - Ok(()) - } -} diff --git a/gemla/src/core/genetic_node.rs b/gemla/src/core/genetic_node.rs new file mode 100644 index 0000000..7ba672b --- /dev/null +++ b/gemla/src/core/genetic_node.rs @@ -0,0 +1,119 @@ +//! A trait used to interact with the internal state of nodes within the [`Bracket`] +//! +//! [`Bracket`]: crate::bracket::Bracket + +use crate::error::Error; + +use anyhow::Context; +use serde::{Deserialize, Serialize}; +use std::fmt::Debug; + +/// An enum used to control the state of a [`GeneticNode`] +/// +/// [`GeneticNode`]: crate::bracket::genetic_node +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Copy)] +#[serde(tag = "enumType", content = "enumContent")] +pub enum GeneticState { + /// The node and it's data have not finished initializing + Initialize, + /// The node is currently simulating a round against target data to determine the fitness of the population + Simulate, + /// The node is currently mutating members of it's population and breeding new members + Mutate, + /// The node has finished processing for a given number of iterations + Finish, +} + +/// A trait used to interact with the internal state of nodes within the [`Bracket`] +/// +/// [`Bracket`]: crate::bracket::Bracket +pub trait GeneticNode { + /// Initializes a new instance of a [`GeneticState`]. + /// + /// # Examples + /// TODO + fn initialize() -> Result, Error>; + + fn simulate(&mut self) -> Result<(), Error>; + + /// Mutates members in a population and/or crossbreeds them to produce new offspring. + /// + /// # Examples + /// TODO + fn mutate(&mut self) -> Result<(), Error>; + + fn merge(left: &Self, right: &Self) -> Result, Error>; +} + +/// Used externally to wrap a node implementing the [`GeneticNode`] trait. Processes state transitions for the given node as +/// well as signal recovery. Transition states are given by [`GeneticState`] +#[derive(Debug, Serialize, Deserialize)] +pub struct GeneticNodeWrapper { + pub node: Option, + state: GeneticState, + generation: u64, + pub total_generations: u64, +} + +impl GeneticNodeWrapper +where + T: GeneticNode + Debug, +{ + pub fn new(total_generations: u64) -> Self { + GeneticNodeWrapper { + node: None, + state: GeneticState::Initialize, + generation: 0, + total_generations, + } + } + + pub fn from(data: T, total_generations: u64) -> Self { + GeneticNodeWrapper { + node: Some(data), + state: GeneticState::Simulate, + generation: 0, + total_generations, + } + } + + pub fn state(&self) -> &GeneticState { + &self.state + } + + pub fn process_node(&mut self) -> Result { + match (&self.state, &self.node) { + (GeneticState::Initialize, _) => { + self.node = Some(*T::initialize()?); + self.state = GeneticState::Simulate; + } + (GeneticState::Simulate, Some(_)) => { + self.node + .as_mut() + .unwrap() + .simulate() + .with_context(|| format!("Error simulating node: {:?}", self))?; + + self.state = if self.generation >= self.total_generations { + GeneticState::Finish + } else { + GeneticState::Mutate + }; + } + (GeneticState::Mutate, Some(_)) => { + self.node + .as_mut() + .unwrap() + .mutate() + .with_context(|| format!("Error mutating node: {:?}", self))?; + + self.generation += 1; + self.state = GeneticState::Simulate; + } + (GeneticState::Finish, Some(_)) => (), + _ => panic!("Error processing node {:?}", self.node), + } + + Ok(self.state) + } +} diff --git a/gemla/src/bracket/mod.rs b/gemla/src/core/mod.rs similarity index 56% rename from gemla/src/bracket/mod.rs rename to gemla/src/core/mod.rs index 5096375..83f3608 100644 --- a/gemla/src/bracket/mod.rs +++ b/gemla/src/core/mod.rs @@ -7,15 +7,23 @@ use crate::error::Error; use crate::tree::Tree; use anyhow::anyhow; use file_linked::FileLinked; -use genetic_node::{GeneticNode, GeneticNodeWrapper}; +use genetic_node::{GeneticNode, GeneticNodeWrapper, GeneticState}; use serde::de::DeserializeOwned; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use std::fmt::Debug; use std::fs::File; use std::io::ErrorKind; use std::mem::swap; use std::path::Path; +type SimulationTree = Tree>; + +#[derive(Serialize, Deserialize)] +pub struct GemlaConfig { + pub generations_per_node: u64, + pub overwrite: bool, +} + /// Creates a tournament style bracket for simulating and evaluating nodes of type `T` implementing [`GeneticNode`]. /// These nodes are built upwards as a balanced binary tree starting from the bottom. This results in `Bracket` building /// a separate tree of the same height then merging trees together. Evaluating populations between nodes and taking the strongest @@ -26,91 +34,106 @@ pub struct Gemla where T: Serialize, { - data: FileLinked>>>>, + pub data: FileLinked<(Option>, GemlaConfig)>, } impl Gemla where T: GeneticNode + Serialize + DeserializeOwned + Debug, { - pub fn new(path: &Path, overwrite: bool) -> Result { + pub fn new(path: &Path, config: GemlaConfig) -> Result { match File::open(path) { Ok(file) => { drop(file); Ok(Gemla { - data: if overwrite { - FileLinked::new(Some(btree!(None)), path)? + data: if config.overwrite { + FileLinked::new((None, config), path)? } else { FileLinked::from_file(path)? }, }) } Err(error) if error.kind() == ErrorKind::NotFound => Ok(Gemla { - data: FileLinked::new(Some(btree!(None)), path)?, + data: FileLinked::new((None, config), path)?, }), Err(error) => Err(Error::IO(error)), } } pub fn simulate(&mut self, steps: u64) -> Result<(), Error> { - self.data.mutate(|d| Gemla::increase_height(d, steps))?; + self.data + .mutate(|(d, c)| Gemla::increase_height(d, c, steps))??; self.data - .mutate(|d| Gemla::process_tree(d.as_mut().unwrap()))??; + .mutate(|(d, _c)| Gemla::process_tree(d.as_mut().unwrap()))??; Ok(()) } - fn build_empty_tree(size: usize) -> Tree>> { - if size <= 1 { - btree!(None) - } else { - btree!( - None, - Gemla::build_empty_tree(size - 1), - Gemla::build_empty_tree(size - 1) - ) - } - } - - fn increase_height(tree: &mut Option>>>, amount: u64) { + fn increase_height( + tree: &mut Option>, + config: &GemlaConfig, + amount: u64, + ) -> Result<(), Error> { for _ in 0..amount { - let height = tree.as_ref().unwrap().height(); - let temp = tree.take(); - swap( - tree, - &mut Some(btree!( - None, - temp.unwrap(), - Gemla::build_empty_tree(height as usize) - )), - ); + if tree.is_none() { + swap( + tree, + &mut Some(btree!(GeneticNodeWrapper::new(config.generations_per_node))), + ); + } else { + let height = tree.as_mut().unwrap().height() as u64; + let temp = tree.take(); + swap( + tree, + &mut Some(btree!( + GeneticNodeWrapper::new(config.generations_per_node), + temp.unwrap(), + btree!(GeneticNodeWrapper::new( + height * config.generations_per_node + )) + )), + ); + } } + + Ok(()) } - fn process_tree(tree: &mut Tree>>) -> Result<(), Error> { - if tree.val.is_none() { + fn process_tree(tree: &mut SimulationTree) -> Result<(), Error> { + if tree.val.state() == &GeneticState::Initialize { match (&mut tree.left, &mut tree.right) { (Some(l), Some(r)) => { Gemla::process_tree(&mut (*l))?; Gemla::process_tree(&mut (*r))?; - let left_node = (*l).val.as_ref().unwrap().data.as_ref().unwrap(); - let right_node = (*r).val.as_ref().unwrap().data.as_ref().unwrap(); + let left_node = (*l).val.node.as_ref().unwrap(); + let right_node = (*r).val.node.as_ref().unwrap(); let merged_node = GeneticNode::merge(left_node, right_node)?; - tree.val = Some(GeneticNodeWrapper::from(*merged_node)?); - tree.val.as_mut().unwrap().process_node(1)?; + tree.val = GeneticNodeWrapper::from(*merged_node, tree.val.total_generations); + Gemla::process_node(&mut tree.val)?; } (None, None) => { - tree.val = Some(GeneticNodeWrapper::new()?); - tree.val.as_mut().unwrap().process_node(1)?; + Gemla::process_node(&mut tree.val)?; } _ => { return Err(Error::Other(anyhow!("unable to process tree {:?}", tree))); } } + } else { + Gemla::process_node(&mut tree.val)?; + } + + Ok(()) + } + + fn process_node(node: &mut GeneticNodeWrapper) -> Result<(), Error> { + loop { + if node.process_node()? == GeneticState::Finish { + break; + } } Ok(()) @@ -119,7 +142,7 @@ where #[cfg(test)] mod tests { - use crate::bracket::*; + use crate::core::*; use serde::{Deserialize, Serialize}; use std::str::FromStr; @@ -138,8 +161,8 @@ mod tests { } impl genetic_node::GeneticNode for TestState { - fn simulate(&mut self, iterations: u64) -> Result<(), Error> { - self.score += iterations as f64; + fn simulate(&mut self) -> Result<(), Error> { + self.score += 1.0; Ok(()) } diff --git a/gemla/src/lib.rs b/gemla/src/lib.rs index 187d872..c40b530 100644 --- a/gemla/src/lib.rs +++ b/gemla/src/lib.rs @@ -3,6 +3,6 @@ extern crate regex; #[macro_use] pub mod tree; -pub mod bracket; pub mod constants; +pub mod core; pub mod error;