From 569a17f145d34d00b9ab1ee9a51174f48c59c4d7 Mon Sep 17 00:00:00 2001 From: vandomej Date: Tue, 5 Oct 2021 16:29:59 -0700 Subject: [PATCH] Refactoring bracket interface --- file_linked/src/lib.rs | 71 ++--- gemla/src/bin/test_state/mod.rs | 18 -- gemla/src/bracket/genetic_node.rs | 142 +++------ gemla/src/bracket/mod.rs | 497 ++++-------------------------- gemla/src/error.rs | 8 + 5 files changed, 137 insertions(+), 599 deletions(-) diff --git a/file_linked/src/lib.rs b/file_linked/src/lib.rs index 95ff5d3..ff33531 100644 --- a/file_linked/src/lib.rs +++ b/file_linked/src/lib.rs @@ -2,12 +2,10 @@ extern crate serde; -use std::fs; -use std::io; +use std::fs::File; use std::io::prelude::*; -use std::path; - -use anyhow::{anyhow, Context}; +use std::path::PathBuf; +use anyhow::Context; use serde::de::DeserializeOwned; use serde::Serialize; use thiserror::Error; @@ -29,7 +27,7 @@ where T: Serialize, { val: T, - path: path::PathBuf, + path: PathBuf, } impl FileLinked @@ -44,7 +42,7 @@ where /// # use serde::{Deserialize, Serialize}; /// # use std::fmt; /// # use std::string::ToString; - /// # use std::path; + /// # use std::path::PathBuf; /// # /// # #[derive(Deserialize, Serialize)] /// # struct Test { @@ -60,7 +58,7 @@ where /// c: 3.0 /// }; /// - /// let linked_test = FileLinked::new(test, path::PathBuf::from("./temp")) + /// let linked_test = FileLinked::new(test, &PathBuf::from("./temp")) /// .expect("Unable to create file linked object"); /// /// assert_eq!(linked_test.readonly().a, 1); @@ -82,7 +80,7 @@ where /// # use serde::{Deserialize, Serialize}; /// # use std::fmt; /// # use std::string::ToString; - /// # use std::path; + /// # use std::path::PathBuf; /// # /// #[derive(Deserialize, Serialize)] /// struct Test { @@ -98,7 +96,7 @@ where /// c: 3.0 /// }; /// - /// let linked_test = FileLinked::new(test, path::PathBuf::from("./temp")) + /// let linked_test = FileLinked::new(test, &PathBuf::from("./temp")) /// .expect("Unable to create file linked object"); /// /// assert_eq!(linked_test.readonly().a, 1); @@ -108,20 +106,14 @@ where /// # std::fs::remove_file("./temp").expect("Unable to remove file"); /// # } /// ``` - pub fn new(val: T, path: path::PathBuf) -> Result, Error> { - let result = FileLinked { val, path }; - + pub fn new(val: T, path: &PathBuf) -> Result, Error> { + let result = FileLinked { val, path: path.clone() }; result.write_data()?; - Ok(result) } fn write_data(&self) -> Result<(), Error> { - let mut file = fs::OpenOptions::new() - .write(true) - .create(true) - .truncate(true) - .open(&self.path) + let mut file = File::create(&self.path) .with_context(|| format!("Unable to open path {}", self.path.display()))?; write!( @@ -143,7 +135,7 @@ where /// # use serde::{Deserialize, Serialize}; /// # use std::fmt; /// # use std::string::ToString; - /// # use std::path; + /// # use std::path::PathBuf; /// # /// # #[derive(Deserialize, Serialize)] /// # struct Test { @@ -159,7 +151,7 @@ where /// c: 0.0 /// }; /// - /// let mut linked_test = FileLinked::new(test, path::PathBuf::from("./temp")) + /// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp")) /// .expect("Unable to create file linked object"); /// /// assert_eq!(linked_test.readonly().a, 1); @@ -189,7 +181,7 @@ where /// # use serde::{Deserialize, Serialize}; /// # use std::fmt; /// # use std::string::ToString; - /// # use std::path; + /// # use std::path::PathBuf; /// # /// # #[derive(Deserialize, Serialize)] /// # struct Test { @@ -205,7 +197,7 @@ where /// c: 0.0 /// }; /// - /// let mut linked_test = FileLinked::new(test, path::PathBuf::from("./temp")) + /// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp")) /// .expect("Unable to create file linked object"); /// /// assert_eq!(linked_test.readonly().a, 1); @@ -245,7 +237,7 @@ where /// # use std::fs; /// # use std::fs::OpenOptions; /// # use std::io::Write; - /// # use std::path; + /// # use std::path::PathBuf; /// # /// # #[derive(Deserialize, Serialize)] /// # struct Test { @@ -261,7 +253,7 @@ where /// c: 3.0 /// }; /// - /// let path = path::PathBuf::from("./temp"); + /// let path = PathBuf::from("./temp"); /// /// let mut file = OpenOptions::new() /// .write(true) @@ -275,7 +267,7 @@ where /// /// drop(file); /// - /// let mut linked_test = FileLinked::::from_file(path) + /// let mut linked_test = FileLinked::::from_file(&path) /// .expect("Unable to create file linked object"); /// /// assert_eq!(linked_test.readonly().a, test.a); @@ -287,27 +279,14 @@ where /// # Ok(()) /// # } /// ``` - pub fn from_file(path: path::PathBuf) -> Result, Error> { - let metadata = path - .metadata() - .with_context(|| format!("Error obtaining metadata for {}", path.display()))?; + pub fn from_file(path: &PathBuf) -> Result, Error> { + let file = File::open(path) + .with_context(|| format!("Unable to open file {}", path.display()))?; - if metadata.is_file() { - let file = fs::OpenOptions::new() - .read(true) - .open(&path) - .with_context(|| format!("Unable to open file {}", path.display()))?; + let val = serde_json::from_reader(file) + .with_context(|| String::from("Unable to parse value from file."))?; - let val = serde_json::from_reader(file) - .with_context(|| String::from("Unable to parse value from file."))?; - - Ok(FileLinked { val, path }) - } else { - return Err(Error::IO(io::Error::new( - io::ErrorKind::Other, - anyhow!("{} is not a file.", path.display()), - ))); - } + Ok(FileLinked { val, path: path.clone() }) } } @@ -319,7 +298,7 @@ mod tests { #[test] fn test_mutate() -> Result<(), Error> { let list = vec![1, 2, 3, 4]; - let mut file_linked_list = FileLinked::new(list, path::PathBuf::from("test.txt"))?; + let mut file_linked_list = FileLinked::new(list, &PathBuf::from("test.txt"))?; assert_eq!(format!("{:?}", file_linked_list.readonly()), "[1, 2, 3, 4]"); diff --git a/gemla/src/bin/test_state/mod.rs b/gemla/src/bin/test_state/mod.rs index eb404e1..10824c5 100644 --- a/gemla/src/bin/test_state/mod.rs +++ b/gemla/src/bin/test_state/mod.rs @@ -40,14 +40,6 @@ impl GeneticNode for TestState { Ok(()) } - fn get_fit_score(&self) -> f64 { - self.population - .clone() - .into_iter() - .reduce(f64::max) - .unwrap() - } - fn calculate_scores_and_trim(&mut self) -> Result<(), error::Error> { let mut v = self.population.clone(); @@ -144,16 +136,6 @@ mod tests { .all(|(&a, &b)| b >= a - 30.0 && b <= a + 30.0)) } - #[test] - fn test_get_fit_score() { - let state = TestState { - thread_rng: thread_rng(), - population: vec![1.0, 1.0, 2.0, 3.0], - }; - - assert_eq!(state.get_fit_score(), 3.0); - } - #[test] fn test_calculate_scores_and_trim() { let mut state = TestState { diff --git a/gemla/src/bracket/genetic_node.rs b/gemla/src/bracket/genetic_node.rs index 9875d14..b6d5e20 100644 --- a/gemla/src/bracket/genetic_node.rs +++ b/gemla/src/bracket/genetic_node.rs @@ -53,10 +53,6 @@ pub trait GeneticNode { /// # Ok(()) /// # } /// # - /// # fn get_fit_score(&self) -> f64 { - /// # self.fit_score - /// # } - /// # /// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> { /// # Ok(()) /// # } @@ -72,7 +68,7 @@ pub trait GeneticNode { /// /// # fn main() -> Result<(), Error> { /// let node = Node::initialize()?; - /// assert_eq!(node.get_fit_score(), 0.0); + /// assert_eq!(node.fit_score, 0.0); /// # Ok(()) /// # } /// ``` @@ -81,7 +77,7 @@ pub trait GeneticNode { /// Runs a simulation on the state object for the given number of `iterations` in order to guage it's fitness. /// This will be called for every node in a bracket before evaluating it's fitness against other nodes. /// - /// #Examples + /// # Examples /// /// ``` /// # use gemla::bracket::genetic_node::GeneticNode; @@ -104,7 +100,17 @@ pub trait GeneticNode { /// # Ok(()) /// } /// } - /// + /// + /// # impl Node { + /// # fn get_fit_score(&self) -> f64 { + /// # self.models + /// # .iter() + /// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap()) + /// # .unwrap() + /// # .fit_score + /// # } + /// # } + /// # /// impl GeneticNode for Node { /// # fn initialize() -> Result, Error> { /// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]})) @@ -121,10 +127,6 @@ pub trait GeneticNode { /// } /// /// //... - /// - /// # fn get_fit_score(&self) -> f64 { - /// # self.models.iter().max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap()).unwrap().fit_score - /// # } /// # /// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> { /// # Ok(()) @@ -148,76 +150,6 @@ pub trait GeneticNode { /// ``` fn simulate(&mut self, iterations: u64) -> Result<(), Error>; - /// Returns a fit score associated with the nodes performance. - /// This will be used by a bracket in order to determine the most successful child. - /// - /// # Examples - /// ``` - /// # use gemla::bracket::genetic_node::GeneticNode; - /// # use gemla::error::Error; - /// # - /// struct Model { - /// pub fit_score: f64, - /// //... - /// } - /// - /// struct Node { - /// pub models: Vec, - /// //... - /// } - /// - /// # impl Model { - /// # fn fit(&mut self, epochs: u64) -> Result<(), Error> { - /// # //... - /// # self.fit_score += epochs as f64; - /// # Ok(()) - /// # } - /// # } - /// - /// impl GeneticNode for Node { - /// # fn initialize() -> Result, Error> { - /// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]})) - /// # } - /// # - /// # //... - /// # - /// # fn simulate(&mut self, iterations: u64) -> Result<(), Error> { - /// # for m in self.models.iter_mut() - /// # { - /// # m.fit(iterations)?; - /// # } - /// # Ok(()) - /// # } - /// # - /// //... - /// - /// fn get_fit_score(&self) -> f64 { - /// self.models.iter().max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap()).unwrap().fit_score - /// } - /// - /// //... - /// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> { - /// # Ok(()) - /// # } - /// # - /// # fn mutate(&mut self) -> Result<(), Error> { - /// # Ok(()) - /// # } - /// # - /// # fn merge(left: &Node, right: &Node) -> Result, Error> { - /// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]})) - /// # } - /// } - /// - /// # fn main() -> Result<(), Error> { - /// let mut node = Node::initialize()?; - /// node.simulate(5)?; - /// assert_eq!(node.get_fit_score(), 5.0); - /// # Ok(()) - /// # } - /// ``` - fn get_fit_score(&self) -> f64; - /// Used when scoring the nodes after simulating and should remove underperforming children. /// /// # Examples @@ -239,11 +171,23 @@ pub trait GeneticNode { /// # impl Model { /// # fn fit(&mut self, epochs: u64) -> Result<(), Error> { /// # //... - /// # self.fit_score += epochs as f64; + /// # self.fit_score += epochs as f64; /// # Ok(()) /// # } /// # } - /// + /// # + /// # + /// # impl Node { + /// # fn get_fit_score(&self) -> f64 { + /// # self.models + /// # .iter() + /// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap()) + /// # .unwrap() + /// # .fit_score + /// # } + /// # } + /// # + /// /// impl GeneticNode for Node { /// # fn initialize() -> Result, Error> { /// # Ok(Box::new(Node { @@ -269,14 +213,6 @@ pub trait GeneticNode { /// # /// //... /// - /// # fn get_fit_score(&self) -> f64 { - /// # self.models - /// # .iter() - /// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap()) - /// # .unwrap() - /// # .fit_score - /// # } - /// # /// fn calculate_scores_and_trim(&mut self) -> Result<(), Error> { /// self.models.sort_by(|a, b| a.fit_score.partial_cmp(&b.fit_score).unwrap().reverse()); /// self.models.truncate(3); @@ -326,6 +262,16 @@ pub trait GeneticNode { /// population_size: i64, /// //... /// } + /// # + /// # impl Node { + /// # fn get_fit_score(&self) -> f64 { + /// # self.models + /// # .iter() + /// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap()) + /// # .unwrap() + /// # .fit_score + /// # } + /// # } /// /// # impl Model { /// # fn fit(&mut self, epochs: u64) -> Result<(), Error> { @@ -364,14 +310,6 @@ pub trait GeneticNode { /// # Ok(()) /// # } /// # - /// # fn get_fit_score(&self) -> f64 { - /// # self.models - /// # .iter() - /// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap()) - /// # .unwrap() - /// # .fit_score - /// # } - /// # /// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> { /// # self.models.sort_by(|a, b| a.fit_score.partial_cmp(&b.fit_score).unwrap().reverse()); /// # self.models.truncate(3); @@ -458,10 +396,6 @@ where /// # Ok(()) /// # } /// # - /// # fn get_fit_score(&self) -> f64 { - /// # self.fit_score - /// # } - /// # /// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> { /// # Ok(()) /// # } @@ -477,7 +411,7 @@ where /// /// # fn main() -> Result<(), Error> { /// let mut wrapped_node = GeneticNodeWrapper::::new()?; - /// assert_eq!(wrapped_node.data.unwrap().get_fit_score(), 0.0); + /// assert_eq!(wrapped_node.data.unwrap().fit_score, 0.0); /// # Ok(()) /// # } /// ``` diff --git a/gemla/src/bracket/mod.rs b/gemla/src/bracket/mod.rs index 82ecbda..1774246 100644 --- a/gemla/src/bracket/mod.rs +++ b/gemla/src/bracket/mod.rs @@ -5,78 +5,21 @@ pub mod genetic_node; use crate::error::Error; use crate::tree; -use genetic_node::GeneticNodeWrapper; - +use genetic_node::{GeneticNodeWrapper, GeneticNode}; use file_linked::FileLinked; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use std::fmt::Debug; -use std::path; +use std::path::PathBuf; +use std::fs::File; +use std::io::ErrorKind; /// As the bracket tree increases in height, `IterationScaling` can be used to configure the number of iterations that /// a node runs for. /// /// # Examples /// -/// ``` -/// # use gemla::bracket::*; -/// # use gemla::error::Error; -/// # use serde::{Deserialize, Serialize}; -/// # use std::fmt; -/// # use std::str::FromStr; -/// # use std::string::ToString; -/// # use std::path; -/// # -/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq, Debug)] -/// # struct TestState { -/// # pub score: f64, -/// # } -/// # -/// # impl TestState { -/// # fn new(score: f64) -> TestState { -/// # TestState { score: score } -/// # } -/// # } -/// # -/// # impl genetic_node::GeneticNode for TestState { -/// # fn simulate(&mut self, iterations: u64) -> Result<(), Error> { -/// # self.score += iterations as f64; -/// # Ok(()) -/// # } -/// # -/// # fn get_fit_score(&self) -> f64 { -/// # self.score -/// # } -/// # -/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> { -/// # Ok(()) -/// # } -/// # -/// # fn mutate(&mut self) -> Result<(), Error> { -/// # Ok(()) -/// # } -/// # -/// # fn initialize() -> Result, Error> { -/// # Ok(Box::new(TestState { score: 0.0 })) -/// # } -/// # -/// # fn merge(left: &TestState, right: &TestState) -> Result, Error> { -/// # Ok(Box::new(left.clone())) -/// # } -/// # } -/// # -/// # fn main() { -/// let mut bracket = Bracket::::initialize(path::PathBuf::from("./temp")) -/// .expect("Bracket failed to initialize"); -/// -/// // Constant iteration scaling ensures that every node is simulated 5 times. -/// bracket -/// .mutate(|b| drop(b.iteration_scaling(IterationScaling::Constant(5)))) -/// .expect("Failed to set iteration scaling"); -/// -/// # std::fs::remove_file("./temp").expect("Unable to remove file"); -/// # } -/// ``` +/// TODO #[derive(Clone, Serialize, Deserialize, Copy, Debug, PartialEq)] #[serde(tag = "enumType", content = "enumContent")] pub enum IterationScaling { @@ -93,336 +36,87 @@ impl Default for IterationScaling { } } -/// Creates a tournament style bracket for simulating and evaluating nodes of type `T` implementing [`GeneticNode`]. -/// These nodes are built upwards as a balanced binary tree starting from the bottom. This results in `Bracket` building -/// a separate tree of the same height then merging trees together. Evaluating populations between nodes and taking the strongest -/// individuals. -/// -/// [`GeneticNode`]: genetic_node::GeneticNode #[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] -pub struct Bracket +struct Bracket where - T: genetic_node::GeneticNode + Serialize, + T: GeneticNode + Serialize, { pub tree: tree::Tree>>, iteration_scaling: IterationScaling, } impl Bracket -where - T: genetic_node::GeneticNode - + Default - + DeserializeOwned - + Serialize - + Clone - + PartialEq - + Debug, +where T: GeneticNode + Serialize { - /// Initializes a bracket of type `T` storing the contents to `file_path` - /// - /// # Examples - /// ``` - /// # use gemla::bracket::*; - /// # use gemla::btree; - /// # use gemla::tree; - /// # use gemla::error::Error; - /// # use serde::{Deserialize, Serialize}; - /// # use std::fmt; - /// # use std::str::FromStr; - /// # use std::string::ToString; - /// # use std::path; - /// # - /// #[derive(Default, Deserialize, Serialize, Debug, Clone, PartialEq)] - /// struct TestState { - /// pub score: f64, - /// } - /// - /// # impl FromStr for TestState { - /// # type Err = String; - /// # - /// # fn from_str(s: &str) -> Result { - /// # serde_json::from_str(s).map_err(|_| format!("Unable to parse string {}", s)) - /// # } - /// # } - /// # - /// # impl fmt::Display for TestState { - /// # fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - /// # write!(f, "{}", self.score) - /// # } - /// # } - /// # - /// impl TestState { - /// fn new(score: f64) -> TestState { - /// TestState { score: score } - /// } - /// } - /// - /// impl genetic_node::GeneticNode for TestState { - /// # fn simulate(&mut self, iterations: u64) -> Result<(), Error> { - /// # self.score += iterations as f64; - /// # Ok(()) - /// # } - /// # - /// # fn get_fit_score(&self) -> f64 { - /// # self.score - /// # } - /// # - /// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> { - /// # Ok(()) - /// # } - /// # - /// # fn mutate(&mut self) -> Result<(), Error> { - /// # Ok(()) - /// # } - /// # - /// fn initialize() -> Result, Error> { - /// Ok(Box::new(TestState { score: 0.0 })) - /// } - /// - /// //... - /// # - /// # fn merge(left: &TestState, right: &TestState) -> Result, Error> { - /// # Ok(Box::new(left.clone())) - /// # } - /// } - /// - /// # fn main() { - /// let mut bracket = Bracket::::initialize(path::PathBuf::from("./temp")) - /// .expect("Bracket failed to initialize"); - /// - /// std::fs::remove_file("./temp").expect("Unable to remove file"); - /// # } - /// ``` - pub fn initialize(file_path: path::PathBuf) -> Result, Error> { - Ok(FileLinked::new( - Bracket { - tree: btree!(Some(GeneticNodeWrapper::new()?)), - iteration_scaling: IterationScaling::default(), + fn increase_height(&mut self, _amount: usize) -> Result<(), Error> { + Ok(()) + } + + fn process_tree(&mut self) -> Result<(), Error> { + Ok(()) + } +} + +/// Creates a tournament style bracket for simulating and evaluating nodes of type `T` implementing [`GeneticNode`]. +/// These nodes are built upwards as a balanced binary tree starting from the bottom. This results in `Bracket` building +/// a separate tree of the same height then merging trees together. Evaluating populations between nodes and taking the strongest +/// individuals. +/// +/// [`GeneticNode`]: genetic_node::GeneticNode +pub struct Gemla +where T: GeneticNode + Serialize + DeserializeOwned +{ + data: FileLinked> +} + +impl Gemla +where + T: GeneticNode + + Serialize + + DeserializeOwned + + Default +{ + pub fn new(path: &PathBuf, overwrite: bool) -> Result { + match File::open(path) { + Ok(file) => { + drop(file); + + Ok(Gemla { + data: + if overwrite { + FileLinked::from_file(path)? + } else { + FileLinked::new(Bracket { + tree: btree!(None), + iteration_scaling: IterationScaling::default() + }, path)? + } + }) }, - file_path, - )?) - } - - /// Given a bracket object, configures it's [`IterationScaling`]. - /// - /// # Examples - /// ``` - /// # use gemla::bracket::*; - /// # use gemla::error::Error; - /// # use serde::{Deserialize, Serialize}; - /// # use std::fmt; - /// # use std::str::FromStr; - /// # use std::string::ToString; - /// # use std::path; - /// # - /// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq, Debug)] - /// # struct TestState { - /// # pub score: f64, - /// # } - /// # - /// # impl fmt::Display for TestState { - /// # fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - /// # write!(f, "{}", self.score) - /// # } - /// # } - /// # - /// # impl TestState { - /// # fn new(score: f64) -> TestState { - /// # TestState { score: score } - /// # } - /// # } - /// # - /// # impl genetic_node::GeneticNode for TestState { - /// # fn simulate(&mut self, iterations: u64) -> Result<(), Error> { - /// # self.score += iterations as f64; - /// # Ok(()) - /// # } - /// # - /// # fn get_fit_score(&self) -> f64 { - /// # self.score - /// # } - /// # - /// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> { - /// # Ok(()) - /// # } - /// # - /// # fn mutate(&mut self) -> Result<(), Error> { - /// # Ok(()) - /// # } - /// # - /// # fn initialize() -> Result, Error> { - /// # Ok(Box::new(TestState { score: 0.0 })) - /// # } - /// # - /// # fn merge(left: &TestState, right: &TestState) -> Result, Error> { - /// # Ok(Box::new(left.clone())) - /// # } - /// # } - /// # - /// # fn main() { - /// let mut bracket = Bracket::::initialize(path::PathBuf::from("./temp")) - /// .expect("Bracket failed to initialize"); - /// - /// // Constant iteration scaling ensures that every node is simulated 5 times. - /// bracket - /// .mutate(|b| drop(b.iteration_scaling(IterationScaling::Constant(5)))) - /// .expect("Failed to set iteration scaling"); - /// - /// # std::fs::remove_file("./temp").expect("Unable to remove file"); - /// # } - /// ``` - pub fn iteration_scaling(&mut self, iteration_scaling: IterationScaling) -> &mut Self { - self.iteration_scaling = iteration_scaling; - self - } - - // Creates a balanced tree with the given `height` that will be used as a branch of the primary tree. - // This additionally simulates and evaluates nodes in the branch as it is built. - fn create_new_branch( - &self, - height: u64, - ) -> Result>>, Error> { - if height == 1 { - let mut base_node = GeneticNodeWrapper::new()?; - - base_node.process_node(match self.iteration_scaling { - IterationScaling::Linear(x) => x * height, - IterationScaling::Constant(x) => x, - })?; - - Ok(btree!(Some(base_node))) - } else { - let left = self.create_new_branch(height - 1)?; - let right = self.create_new_branch(height - 1)?; - let mut new_val = if left.val.clone().unwrap().data.unwrap().get_fit_score() - >= right.val.clone().unwrap().data.unwrap().get_fit_score() - { - left.val.clone().unwrap() - } else { - right.val.clone().unwrap() - }; - - new_val.process_node(match self.iteration_scaling { - IterationScaling::Linear(x) => x * height, - IterationScaling::Constant(x) => x, - })?; - - Ok(btree!(Some(new_val), left, right)) + Err(error) if error.kind() == ErrorKind::NotFound => { + Ok(Gemla { + data: FileLinked::new(Bracket { + tree: btree!(None), + iteration_scaling: IterationScaling::default() + }, path)? + }) + }, + Err(error) => Err(Error::IO(error)) } } - /// Runs one step of simulation on the current bracket which includes: - /// 1) Creating a new branch of the same height and performing the same steps for each subtree. - /// 2) Simulating the top node of the current branch. - /// 3) Comparing the top node of the current branch to the top node of the new branch. - /// 4) Takes the best performing node and makes it the root of the tree. - /// - /// # Examples - /// ``` - /// # use gemla::bracket::*; - /// # use gemla::error::Error; - /// # use serde::{Deserialize, Serialize}; - /// # use std::fmt; - /// # use std::str::FromStr; - /// # use std::string::ToString; - /// # use std::path; - /// # - /// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq, Debug)] - /// # struct TestState { - /// # pub score: f64, - /// # } - /// # - /// # impl fmt::Display for TestState { - /// # fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - /// # write!(f, "{}", self.score) - /// # } - /// # } - /// # - /// # impl TestState { - /// # fn new(score: f64) -> TestState { - /// # TestState { score: score } - /// # } - /// # } - /// # - /// # impl genetic_node::GeneticNode for TestState { - /// # fn simulate(&mut self, iterations: u64) -> Result<(), Error> { - /// # self.score += iterations as f64; - /// # Ok(()) - /// # } - /// # - /// # fn get_fit_score(&self) -> f64 { - /// # self.score - /// # } - /// # - /// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> { - /// # Ok(()) - /// # } - /// # - /// # fn mutate(&mut self) -> Result<(), Error> { - /// # Ok(()) - /// # } - /// # - /// # fn initialize() -> Result, Error> { - /// # Ok(Box::new(TestState { score: 0.0 })) - /// # } - /// # - /// # fn merge(left: &TestState, right: &TestState) -> Result, Error> { - /// # Ok(Box::new(left.clone())) - /// # } - /// # } - /// # - /// # fn main() { - /// let mut bracket = Bracket::::initialize(path::PathBuf::from("./temp")) - /// .expect("Bracket failed to initialize"); - /// - /// // Running simulations 3 times - /// for _ in 0..3 { - /// bracket - /// .mutate(|b| drop(b.run_simulation_step())) - /// .expect("Failed to run step"); - /// } - /// - /// assert_eq!(bracket.readonly().tree.height(), 4); - /// - /// # std::fs::remove_file("./temp").expect("Unable to remove file"); - /// # } - /// ``` - pub fn run_simulation_step(&mut self) -> Result<&mut Self, Error> { - let new_branch = self.create_new_branch(self.tree.height())?; + pub fn simulate(&mut self, steps: u64) -> Result<(), Error> { + self.data.mutate(|b| b.increase_height(steps as usize))??; - self.tree - .val - .clone() - .unwrap() - .process_node(match self.iteration_scaling { - IterationScaling::Linear(x) => (x * self.tree.height()), - IterationScaling::Constant(x) => x, - })?; + self.data.mutate(|b| b.process_tree())??; - let new_val = if new_branch - .val - .clone() - .unwrap() - .data - .unwrap() - .get_fit_score() - >= self.tree.val.clone().unwrap().data.unwrap().get_fit_score() - { - new_branch.val.clone() - } else { - self.tree.val.clone() - }; - - self.tree = btree!(new_val, new_branch, self.tree.clone()); - - Ok(self) + Ok(()) } } #[cfg(test)] mod tests { use crate::bracket::*; - use crate::tree::*; use serde::{Deserialize, Serialize}; use std::str::FromStr; @@ -446,10 +140,6 @@ mod tests { Ok(()) } - fn get_fit_score(&self) -> f64 { - self.score - } - fn calculate_scores_and_trim(&mut self) -> Result<(), Error> { Ok(()) } @@ -463,66 +153,11 @@ mod tests { } fn merge(left: &TestState, right: &TestState) -> Result, Error> { - Ok(Box::new(if left.get_fit_score() > right.get_fit_score() { + Ok(Box::new(if left.score > right.score { left.clone() } else { right.clone() })) } } - - #[test] - fn test_new() { - let bracket = Bracket::::initialize(path::PathBuf::from("./temp")) - .expect("Bracket failed to initialize"); - - assert_eq!( - bracket, - file_linked::FileLinked::new( - Bracket { - tree: Tree { - val: Some(GeneticNodeWrapper::new().unwrap()), - left: None, - right: None - }, - iteration_scaling: IterationScaling::Constant(1) - }, - path::PathBuf::from("./temp") - ) - .unwrap() - ); - - std::fs::remove_file("./temp").expect("Unable to remove file"); - } - - #[test] - fn test_run() { - let mut bracket = Bracket::::initialize(path::PathBuf::from("./temp2")) - .expect("Bracket failed to initialize"); - - bracket - .mutate(|b| drop(b.iteration_scaling(IterationScaling::Linear(2)))) - .expect("Failed to set iteration scaling"); - for _ in 0..3 { - bracket - .mutate(|b| drop(b.run_simulation_step())) - .expect("Failed to run step"); - } - - assert_eq!(bracket.readonly().tree.height(), 4); - assert_eq!( - bracket - .readonly() - .tree - .val - .clone() - .unwrap() - .data - .unwrap() - .score, - 15.0 - ); - - std::fs::remove_file("./temp2").expect("Unable to remove file"); - } } diff --git a/gemla/src/error.rs b/gemla/src/error.rs index 95f8858..947e824 100644 --- a/gemla/src/error.rs +++ b/gemla/src/error.rs @@ -5,6 +5,8 @@ pub enum Error { #[error(transparent)] FileLinked(file_linked::Error), #[error(transparent)] + IO(std::io::Error), + #[error(transparent)] Other(#[from] anyhow::Error), } @@ -16,3 +18,9 @@ impl From for Error { } } } + +impl From for Error { + fn from(error: std::io::Error) -> Error { + Error::IO(error) + } +}