Refactoring bracket interface

This commit is contained in:
vandomej 2021-10-05 16:29:59 -07:00
parent 9081fb0b3c
commit 569a17f145
5 changed files with 137 additions and 599 deletions

View file

@ -2,12 +2,10 @@
extern crate serde; extern crate serde;
use std::fs; use std::fs::File;
use std::io;
use std::io::prelude::*; use std::io::prelude::*;
use std::path; use std::path::PathBuf;
use anyhow::Context;
use anyhow::{anyhow, Context};
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use serde::Serialize; use serde::Serialize;
use thiserror::Error; use thiserror::Error;
@ -29,7 +27,7 @@ where
T: Serialize, T: Serialize,
{ {
val: T, val: T,
path: path::PathBuf, path: PathBuf,
} }
impl<T> FileLinked<T> impl<T> FileLinked<T>
@ -44,7 +42,7 @@ where
/// # use serde::{Deserialize, Serialize}; /// # use serde::{Deserialize, Serialize};
/// # use std::fmt; /// # use std::fmt;
/// # use std::string::ToString; /// # use std::string::ToString;
/// # use std::path; /// # use std::path::PathBuf;
/// # /// #
/// # #[derive(Deserialize, Serialize)] /// # #[derive(Deserialize, Serialize)]
/// # struct Test { /// # struct Test {
@ -60,7 +58,7 @@ where
/// c: 3.0 /// c: 3.0
/// }; /// };
/// ///
/// let linked_test = FileLinked::new(test, path::PathBuf::from("./temp")) /// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"))
/// .expect("Unable to create file linked object"); /// .expect("Unable to create file linked object");
/// ///
/// assert_eq!(linked_test.readonly().a, 1); /// assert_eq!(linked_test.readonly().a, 1);
@ -82,7 +80,7 @@ where
/// # use serde::{Deserialize, Serialize}; /// # use serde::{Deserialize, Serialize};
/// # use std::fmt; /// # use std::fmt;
/// # use std::string::ToString; /// # use std::string::ToString;
/// # use std::path; /// # use std::path::PathBuf;
/// # /// #
/// #[derive(Deserialize, Serialize)] /// #[derive(Deserialize, Serialize)]
/// struct Test { /// struct Test {
@ -98,7 +96,7 @@ where
/// c: 3.0 /// c: 3.0
/// }; /// };
/// ///
/// let linked_test = FileLinked::new(test, path::PathBuf::from("./temp")) /// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"))
/// .expect("Unable to create file linked object"); /// .expect("Unable to create file linked object");
/// ///
/// assert_eq!(linked_test.readonly().a, 1); /// assert_eq!(linked_test.readonly().a, 1);
@ -108,20 +106,14 @@ where
/// # std::fs::remove_file("./temp").expect("Unable to remove file"); /// # std::fs::remove_file("./temp").expect("Unable to remove file");
/// # } /// # }
/// ``` /// ```
pub fn new(val: T, path: path::PathBuf) -> Result<FileLinked<T>, Error> { pub fn new(val: T, path: &PathBuf) -> Result<FileLinked<T>, Error> {
let result = FileLinked { val, path }; let result = FileLinked { val, path: path.clone() };
result.write_data()?; result.write_data()?;
Ok(result) Ok(result)
} }
fn write_data(&self) -> Result<(), Error> { fn write_data(&self) -> Result<(), Error> {
let mut file = fs::OpenOptions::new() let mut file = File::create(&self.path)
.write(true)
.create(true)
.truncate(true)
.open(&self.path)
.with_context(|| format!("Unable to open path {}", self.path.display()))?; .with_context(|| format!("Unable to open path {}", self.path.display()))?;
write!( write!(
@ -143,7 +135,7 @@ where
/// # use serde::{Deserialize, Serialize}; /// # use serde::{Deserialize, Serialize};
/// # use std::fmt; /// # use std::fmt;
/// # use std::string::ToString; /// # use std::string::ToString;
/// # use std::path; /// # use std::path::PathBuf;
/// # /// #
/// # #[derive(Deserialize, Serialize)] /// # #[derive(Deserialize, Serialize)]
/// # struct Test { /// # struct Test {
@ -159,7 +151,7 @@ where
/// c: 0.0 /// c: 0.0
/// }; /// };
/// ///
/// let mut linked_test = FileLinked::new(test, path::PathBuf::from("./temp")) /// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"))
/// .expect("Unable to create file linked object"); /// .expect("Unable to create file linked object");
/// ///
/// assert_eq!(linked_test.readonly().a, 1); /// assert_eq!(linked_test.readonly().a, 1);
@ -189,7 +181,7 @@ where
/// # use serde::{Deserialize, Serialize}; /// # use serde::{Deserialize, Serialize};
/// # use std::fmt; /// # use std::fmt;
/// # use std::string::ToString; /// # use std::string::ToString;
/// # use std::path; /// # use std::path::PathBuf;
/// # /// #
/// # #[derive(Deserialize, Serialize)] /// # #[derive(Deserialize, Serialize)]
/// # struct Test { /// # struct Test {
@ -205,7 +197,7 @@ where
/// c: 0.0 /// c: 0.0
/// }; /// };
/// ///
/// let mut linked_test = FileLinked::new(test, path::PathBuf::from("./temp")) /// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"))
/// .expect("Unable to create file linked object"); /// .expect("Unable to create file linked object");
/// ///
/// assert_eq!(linked_test.readonly().a, 1); /// assert_eq!(linked_test.readonly().a, 1);
@ -245,7 +237,7 @@ where
/// # use std::fs; /// # use std::fs;
/// # use std::fs::OpenOptions; /// # use std::fs::OpenOptions;
/// # use std::io::Write; /// # use std::io::Write;
/// # use std::path; /// # use std::path::PathBuf;
/// # /// #
/// # #[derive(Deserialize, Serialize)] /// # #[derive(Deserialize, Serialize)]
/// # struct Test { /// # struct Test {
@ -261,7 +253,7 @@ where
/// c: 3.0 /// c: 3.0
/// }; /// };
/// ///
/// let path = path::PathBuf::from("./temp"); /// let path = PathBuf::from("./temp");
/// ///
/// let mut file = OpenOptions::new() /// let mut file = OpenOptions::new()
/// .write(true) /// .write(true)
@ -275,7 +267,7 @@ where
/// ///
/// drop(file); /// drop(file);
/// ///
/// let mut linked_test = FileLinked::<Test>::from_file(path) /// let mut linked_test = FileLinked::<Test>::from_file(&path)
/// .expect("Unable to create file linked object"); /// .expect("Unable to create file linked object");
/// ///
/// assert_eq!(linked_test.readonly().a, test.a); /// assert_eq!(linked_test.readonly().a, test.a);
@ -287,27 +279,14 @@ where
/// # Ok(()) /// # Ok(())
/// # } /// # }
/// ``` /// ```
pub fn from_file(path: path::PathBuf) -> Result<FileLinked<T>, Error> { pub fn from_file(path: &PathBuf) -> Result<FileLinked<T>, Error> {
let metadata = path let file = File::open(path)
.metadata() .with_context(|| format!("Unable to open file {}", path.display()))?;
.with_context(|| format!("Error obtaining metadata for {}", path.display()))?;
if metadata.is_file() { let val = serde_json::from_reader(file)
let file = fs::OpenOptions::new() .with_context(|| String::from("Unable to parse value from file."))?;
.read(true)
.open(&path)
.with_context(|| format!("Unable to open file {}", path.display()))?;
let val = serde_json::from_reader(file) Ok(FileLinked { val, path: path.clone() })
.with_context(|| String::from("Unable to parse value from file."))?;
Ok(FileLinked { val, path })
} else {
return Err(Error::IO(io::Error::new(
io::ErrorKind::Other,
anyhow!("{} is not a file.", path.display()),
)));
}
} }
} }
@ -319,7 +298,7 @@ mod tests {
#[test] #[test]
fn test_mutate() -> Result<(), Error> { fn test_mutate() -> Result<(), Error> {
let list = vec![1, 2, 3, 4]; let list = vec![1, 2, 3, 4];
let mut file_linked_list = FileLinked::new(list, path::PathBuf::from("test.txt"))?; let mut file_linked_list = FileLinked::new(list, &PathBuf::from("test.txt"))?;
assert_eq!(format!("{:?}", file_linked_list.readonly()), "[1, 2, 3, 4]"); assert_eq!(format!("{:?}", file_linked_list.readonly()), "[1, 2, 3, 4]");

View file

@ -40,14 +40,6 @@ impl GeneticNode for TestState {
Ok(()) Ok(())
} }
fn get_fit_score(&self) -> f64 {
self.population
.clone()
.into_iter()
.reduce(f64::max)
.unwrap()
}
fn calculate_scores_and_trim(&mut self) -> Result<(), error::Error> { fn calculate_scores_and_trim(&mut self) -> Result<(), error::Error> {
let mut v = self.population.clone(); let mut v = self.population.clone();
@ -144,16 +136,6 @@ mod tests {
.all(|(&a, &b)| b >= a - 30.0 && b <= a + 30.0)) .all(|(&a, &b)| b >= a - 30.0 && b <= a + 30.0))
} }
#[test]
fn test_get_fit_score() {
let state = TestState {
thread_rng: thread_rng(),
population: vec![1.0, 1.0, 2.0, 3.0],
};
assert_eq!(state.get_fit_score(), 3.0);
}
#[test] #[test]
fn test_calculate_scores_and_trim() { fn test_calculate_scores_and_trim() {
let mut state = TestState { let mut state = TestState {

View file

@ -53,10 +53,6 @@ pub trait GeneticNode {
/// # Ok(()) /// # Ok(())
/// # } /// # }
/// # /// #
/// # fn get_fit_score(&self) -> f64 {
/// # self.fit_score
/// # }
/// #
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> { /// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
/// # Ok(()) /// # Ok(())
/// # } /// # }
@ -72,7 +68,7 @@ pub trait GeneticNode {
/// ///
/// # fn main() -> Result<(), Error> { /// # fn main() -> Result<(), Error> {
/// let node = Node::initialize()?; /// let node = Node::initialize()?;
/// assert_eq!(node.get_fit_score(), 0.0); /// assert_eq!(node.fit_score, 0.0);
/// # Ok(()) /// # Ok(())
/// # } /// # }
/// ``` /// ```
@ -81,7 +77,7 @@ pub trait GeneticNode {
/// Runs a simulation on the state object for the given number of `iterations` in order to guage it's fitness. /// Runs a simulation on the state object for the given number of `iterations` in order to guage it's fitness.
/// This will be called for every node in a bracket before evaluating it's fitness against other nodes. /// This will be called for every node in a bracket before evaluating it's fitness against other nodes.
/// ///
/// #Examples /// # Examples
/// ///
/// ``` /// ```
/// # use gemla::bracket::genetic_node::GeneticNode; /// # use gemla::bracket::genetic_node::GeneticNode;
@ -105,6 +101,16 @@ pub trait GeneticNode {
/// } /// }
/// } /// }
/// ///
/// # impl Node {
/// # fn get_fit_score(&self) -> f64 {
/// # self.models
/// # .iter()
/// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap())
/// # .unwrap()
/// # .fit_score
/// # }
/// # }
/// #
/// impl GeneticNode for Node { /// impl GeneticNode for Node {
/// # fn initialize() -> Result<Box<Node>, Error> { /// # fn initialize() -> Result<Box<Node>, Error> {
/// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]})) /// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]}))
@ -121,10 +127,6 @@ pub trait GeneticNode {
/// } /// }
/// ///
/// //... /// //...
///
/// # fn get_fit_score(&self) -> f64 {
/// # self.models.iter().max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap()).unwrap().fit_score
/// # }
/// # /// #
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> { /// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
/// # Ok(()) /// # Ok(())
@ -148,76 +150,6 @@ pub trait GeneticNode {
/// ``` /// ```
fn simulate(&mut self, iterations: u64) -> Result<(), Error>; fn simulate(&mut self, iterations: u64) -> Result<(), Error>;
/// Returns a fit score associated with the nodes performance.
/// This will be used by a bracket in order to determine the most successful child.
///
/// # Examples
/// ```
/// # use gemla::bracket::genetic_node::GeneticNode;
/// # use gemla::error::Error;
/// #
/// struct Model {
/// pub fit_score: f64,
/// //...
/// }
///
/// struct Node {
/// pub models: Vec<Model>,
/// //...
/// }
///
/// # impl Model {
/// # fn fit(&mut self, epochs: u64) -> Result<(), Error> {
/// # //...
/// # self.fit_score += epochs as f64;
/// # Ok(())
/// # }
/// # }
///
/// impl GeneticNode for Node {
/// # fn initialize() -> Result<Box<Node>, Error> {
/// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]}))
/// # }
/// #
/// # //...
/// #
/// # fn simulate(&mut self, iterations: u64) -> Result<(), Error> {
/// # for m in self.models.iter_mut()
/// # {
/// # m.fit(iterations)?;
/// # }
/// # Ok(())
/// # }
/// #
/// //...
///
/// fn get_fit_score(&self) -> f64 {
/// self.models.iter().max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap()).unwrap().fit_score
/// }
///
/// //...
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
/// # Ok(())
/// # }
/// #
/// # fn mutate(&mut self) -> Result<(), Error> {
/// # Ok(())
/// # }
/// #
/// # fn merge(left: &Node, right: &Node) -> Result<Box<Node>, Error> {
/// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]}))
/// # }
/// }
///
/// # fn main() -> Result<(), Error> {
/// let mut node = Node::initialize()?;
/// node.simulate(5)?;
/// assert_eq!(node.get_fit_score(), 5.0);
/// # Ok(())
/// # }
/// ```
fn get_fit_score(&self) -> f64;
/// Used when scoring the nodes after simulating and should remove underperforming children. /// Used when scoring the nodes after simulating and should remove underperforming children.
/// ///
/// # Examples /// # Examples
@ -239,10 +171,22 @@ pub trait GeneticNode {
/// # impl Model { /// # impl Model {
/// # fn fit(&mut self, epochs: u64) -> Result<(), Error> { /// # fn fit(&mut self, epochs: u64) -> Result<(), Error> {
/// # //... /// # //...
/// # self.fit_score += epochs as f64; /// # self.fit_score += epochs as f64;
/// # Ok(()) /// # Ok(())
/// # } /// # }
/// # } /// # }
/// #
/// #
/// # impl Node {
/// # fn get_fit_score(&self) -> f64 {
/// # self.models
/// # .iter()
/// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap())
/// # .unwrap()
/// # .fit_score
/// # }
/// # }
/// #
/// ///
/// impl GeneticNode for Node { /// impl GeneticNode for Node {
/// # fn initialize() -> Result<Box<Node>, Error> { /// # fn initialize() -> Result<Box<Node>, Error> {
@ -269,14 +213,6 @@ pub trait GeneticNode {
/// # /// #
/// //... /// //...
/// ///
/// # fn get_fit_score(&self) -> f64 {
/// # self.models
/// # .iter()
/// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap())
/// # .unwrap()
/// # .fit_score
/// # }
/// #
/// fn calculate_scores_and_trim(&mut self) -> Result<(), Error> { /// fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
/// self.models.sort_by(|a, b| a.fit_score.partial_cmp(&b.fit_score).unwrap().reverse()); /// self.models.sort_by(|a, b| a.fit_score.partial_cmp(&b.fit_score).unwrap().reverse());
/// self.models.truncate(3); /// self.models.truncate(3);
@ -326,6 +262,16 @@ pub trait GeneticNode {
/// population_size: i64, /// population_size: i64,
/// //... /// //...
/// } /// }
/// #
/// # impl Node {
/// # fn get_fit_score(&self) -> f64 {
/// # self.models
/// # .iter()
/// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap())
/// # .unwrap()
/// # .fit_score
/// # }
/// # }
/// ///
/// # impl Model { /// # impl Model {
/// # fn fit(&mut self, epochs: u64) -> Result<(), Error> { /// # fn fit(&mut self, epochs: u64) -> Result<(), Error> {
@ -364,14 +310,6 @@ pub trait GeneticNode {
/// # Ok(()) /// # Ok(())
/// # } /// # }
/// # /// #
/// # fn get_fit_score(&self) -> f64 {
/// # self.models
/// # .iter()
/// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap())
/// # .unwrap()
/// # .fit_score
/// # }
/// #
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> { /// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
/// # self.models.sort_by(|a, b| a.fit_score.partial_cmp(&b.fit_score).unwrap().reverse()); /// # self.models.sort_by(|a, b| a.fit_score.partial_cmp(&b.fit_score).unwrap().reverse());
/// # self.models.truncate(3); /// # self.models.truncate(3);
@ -458,10 +396,6 @@ where
/// # Ok(()) /// # Ok(())
/// # } /// # }
/// # /// #
/// # fn get_fit_score(&self) -> f64 {
/// # self.fit_score
/// # }
/// #
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> { /// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
/// # Ok(()) /// # Ok(())
/// # } /// # }
@ -477,7 +411,7 @@ where
/// ///
/// # fn main() -> Result<(), Error> { /// # fn main() -> Result<(), Error> {
/// let mut wrapped_node = GeneticNodeWrapper::<Node>::new()?; /// let mut wrapped_node = GeneticNodeWrapper::<Node>::new()?;
/// assert_eq!(wrapped_node.data.unwrap().get_fit_score(), 0.0); /// assert_eq!(wrapped_node.data.unwrap().fit_score, 0.0);
/// # Ok(()) /// # Ok(())
/// # } /// # }
/// ``` /// ```

View file

@ -5,78 +5,21 @@ pub mod genetic_node;
use crate::error::Error; use crate::error::Error;
use crate::tree; use crate::tree;
use genetic_node::GeneticNodeWrapper; use genetic_node::{GeneticNodeWrapper, GeneticNode};
use file_linked::FileLinked; use file_linked::FileLinked;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::fmt::Debug; use std::fmt::Debug;
use std::path; use std::path::PathBuf;
use std::fs::File;
use std::io::ErrorKind;
/// As the bracket tree increases in height, `IterationScaling` can be used to configure the number of iterations that /// As the bracket tree increases in height, `IterationScaling` can be used to configure the number of iterations that
/// a node runs for. /// a node runs for.
/// ///
/// # Examples /// # Examples
/// ///
/// ``` /// TODO
/// # use gemla::bracket::*;
/// # use gemla::error::Error;
/// # use serde::{Deserialize, Serialize};
/// # use std::fmt;
/// # use std::str::FromStr;
/// # use std::string::ToString;
/// # use std::path;
/// #
/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq, Debug)]
/// # struct TestState {
/// # pub score: f64,
/// # }
/// #
/// # impl TestState {
/// # fn new(score: f64) -> TestState {
/// # TestState { score: score }
/// # }
/// # }
/// #
/// # impl genetic_node::GeneticNode for TestState {
/// # fn simulate(&mut self, iterations: u64) -> Result<(), Error> {
/// # self.score += iterations as f64;
/// # Ok(())
/// # }
/// #
/// # fn get_fit_score(&self) -> f64 {
/// # self.score
/// # }
/// #
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
/// # Ok(())
/// # }
/// #
/// # fn mutate(&mut self) -> Result<(), Error> {
/// # Ok(())
/// # }
/// #
/// # fn initialize() -> Result<Box<Self>, Error> {
/// # Ok(Box::new(TestState { score: 0.0 }))
/// # }
/// #
/// # fn merge(left: &TestState, right: &TestState) -> Result<Box<Self>, Error> {
/// # Ok(Box::new(left.clone()))
/// # }
/// # }
/// #
/// # fn main() {
/// let mut bracket = Bracket::<TestState>::initialize(path::PathBuf::from("./temp"))
/// .expect("Bracket failed to initialize");
///
/// // Constant iteration scaling ensures that every node is simulated 5 times.
/// bracket
/// .mutate(|b| drop(b.iteration_scaling(IterationScaling::Constant(5))))
/// .expect("Failed to set iteration scaling");
///
/// # std::fs::remove_file("./temp").expect("Unable to remove file");
/// # }
/// ```
#[derive(Clone, Serialize, Deserialize, Copy, Debug, PartialEq)] #[derive(Clone, Serialize, Deserialize, Copy, Debug, PartialEq)]
#[serde(tag = "enumType", content = "enumContent")] #[serde(tag = "enumType", content = "enumContent")]
pub enum IterationScaling { pub enum IterationScaling {
@ -93,336 +36,87 @@ impl Default for IterationScaling {
} }
} }
/// Creates a tournament style bracket for simulating and evaluating nodes of type `T` implementing [`GeneticNode`].
/// These nodes are built upwards as a balanced binary tree starting from the bottom. This results in `Bracket` building
/// a separate tree of the same height then merging trees together. Evaluating populations between nodes and taking the strongest
/// individuals.
///
/// [`GeneticNode`]: genetic_node::GeneticNode
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] #[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct Bracket<T> struct Bracket<T>
where where
T: genetic_node::GeneticNode + Serialize, T: GeneticNode + Serialize,
{ {
pub tree: tree::Tree<Option<GeneticNodeWrapper<T>>>, pub tree: tree::Tree<Option<GeneticNodeWrapper<T>>>,
iteration_scaling: IterationScaling, iteration_scaling: IterationScaling,
} }
impl<T> Bracket<T> impl<T> Bracket<T>
where where T: GeneticNode + Serialize
T: genetic_node::GeneticNode
+ Default
+ DeserializeOwned
+ Serialize
+ Clone
+ PartialEq
+ Debug,
{ {
/// Initializes a bracket of type `T` storing the contents to `file_path` fn increase_height(&mut self, _amount: usize) -> Result<(), Error> {
/// Ok(())
/// # Examples }
/// ```
/// # use gemla::bracket::*; fn process_tree(&mut self) -> Result<(), Error> {
/// # use gemla::btree; Ok(())
/// # use gemla::tree; }
/// # use gemla::error::Error; }
/// # use serde::{Deserialize, Serialize};
/// # use std::fmt; /// Creates a tournament style bracket for simulating and evaluating nodes of type `T` implementing [`GeneticNode`].
/// # use std::str::FromStr; /// These nodes are built upwards as a balanced binary tree starting from the bottom. This results in `Bracket` building
/// # use std::string::ToString; /// a separate tree of the same height then merging trees together. Evaluating populations between nodes and taking the strongest
/// # use std::path; /// individuals.
/// # ///
/// #[derive(Default, Deserialize, Serialize, Debug, Clone, PartialEq)] /// [`GeneticNode`]: genetic_node::GeneticNode
/// struct TestState { pub struct Gemla<T>
/// pub score: f64, where T: GeneticNode + Serialize + DeserializeOwned
/// } {
/// data: FileLinked<Bracket<T>>
/// # impl FromStr for TestState { }
/// # type Err = String;
/// # impl<T> Gemla<T>
/// # fn from_str(s: &str) -> Result<TestState, Self::Err> { where
/// # serde_json::from_str(s).map_err(|_| format!("Unable to parse string {}", s)) T: GeneticNode
/// # } + Serialize
/// # } + DeserializeOwned
/// # + Default
/// # impl fmt::Display for TestState { {
/// # fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { pub fn new(path: &PathBuf, overwrite: bool) -> Result<Self, Error> {
/// # write!(f, "{}", self.score) match File::open(path) {
/// # } Ok(file) => {
/// # } drop(file);
/// #
/// impl TestState { Ok(Gemla {
/// fn new(score: f64) -> TestState { data:
/// TestState { score: score } if overwrite {
/// } FileLinked::from_file(path)?
/// } } else {
/// FileLinked::new(Bracket {
/// impl genetic_node::GeneticNode for TestState { tree: btree!(None),
/// # fn simulate(&mut self, iterations: u64) -> Result<(), Error> { iteration_scaling: IterationScaling::default()
/// # self.score += iterations as f64; }, path)?
/// # Ok(()) }
/// # } })
/// #
/// # fn get_fit_score(&self) -> f64 {
/// # self.score
/// # }
/// #
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
/// # Ok(())
/// # }
/// #
/// # fn mutate(&mut self) -> Result<(), Error> {
/// # Ok(())
/// # }
/// #
/// fn initialize() -> Result<Box<Self>, Error> {
/// Ok(Box::new(TestState { score: 0.0 }))
/// }
///
/// //...
/// #
/// # fn merge(left: &TestState, right: &TestState) -> Result<Box<Self>, Error> {
/// # Ok(Box::new(left.clone()))
/// # }
/// }
///
/// # fn main() {
/// let mut bracket = Bracket::<TestState>::initialize(path::PathBuf::from("./temp"))
/// .expect("Bracket failed to initialize");
///
/// std::fs::remove_file("./temp").expect("Unable to remove file");
/// # }
/// ```
pub fn initialize(file_path: path::PathBuf) -> Result<FileLinked<Self>, Error> {
Ok(FileLinked::new(
Bracket {
tree: btree!(Some(GeneticNodeWrapper::new()?)),
iteration_scaling: IterationScaling::default(),
}, },
file_path, Err(error) if error.kind() == ErrorKind::NotFound => {
)?) Ok(Gemla {
} data: FileLinked::new(Bracket {
tree: btree!(None),
/// Given a bracket object, configures it's [`IterationScaling`]. iteration_scaling: IterationScaling::default()
/// }, path)?
/// # Examples })
/// ``` },
/// # use gemla::bracket::*; Err(error) => Err(Error::IO(error))
/// # use gemla::error::Error;
/// # use serde::{Deserialize, Serialize};
/// # use std::fmt;
/// # use std::str::FromStr;
/// # use std::string::ToString;
/// # use std::path;
/// #
/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq, Debug)]
/// # struct TestState {
/// # pub score: f64,
/// # }
/// #
/// # impl fmt::Display for TestState {
/// # fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
/// # write!(f, "{}", self.score)
/// # }
/// # }
/// #
/// # impl TestState {
/// # fn new(score: f64) -> TestState {
/// # TestState { score: score }
/// # }
/// # }
/// #
/// # impl genetic_node::GeneticNode for TestState {
/// # fn simulate(&mut self, iterations: u64) -> Result<(), Error> {
/// # self.score += iterations as f64;
/// # Ok(())
/// # }
/// #
/// # fn get_fit_score(&self) -> f64 {
/// # self.score
/// # }
/// #
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
/// # Ok(())
/// # }
/// #
/// # fn mutate(&mut self) -> Result<(), Error> {
/// # Ok(())
/// # }
/// #
/// # fn initialize() -> Result<Box<Self>, Error> {
/// # Ok(Box::new(TestState { score: 0.0 }))
/// # }
/// #
/// # fn merge(left: &TestState, right: &TestState) -> Result<Box<Self>, Error> {
/// # Ok(Box::new(left.clone()))
/// # }
/// # }
/// #
/// # fn main() {
/// let mut bracket = Bracket::<TestState>::initialize(path::PathBuf::from("./temp"))
/// .expect("Bracket failed to initialize");
///
/// // Constant iteration scaling ensures that every node is simulated 5 times.
/// bracket
/// .mutate(|b| drop(b.iteration_scaling(IterationScaling::Constant(5))))
/// .expect("Failed to set iteration scaling");
///
/// # std::fs::remove_file("./temp").expect("Unable to remove file");
/// # }
/// ```
pub fn iteration_scaling(&mut self, iteration_scaling: IterationScaling) -> &mut Self {
self.iteration_scaling = iteration_scaling;
self
}
// Creates a balanced tree with the given `height` that will be used as a branch of the primary tree.
// This additionally simulates and evaluates nodes in the branch as it is built.
fn create_new_branch(
&self,
height: u64,
) -> Result<tree::Tree<Option<GeneticNodeWrapper<T>>>, Error> {
if height == 1 {
let mut base_node = GeneticNodeWrapper::new()?;
base_node.process_node(match self.iteration_scaling {
IterationScaling::Linear(x) => x * height,
IterationScaling::Constant(x) => x,
})?;
Ok(btree!(Some(base_node)))
} else {
let left = self.create_new_branch(height - 1)?;
let right = self.create_new_branch(height - 1)?;
let mut new_val = if left.val.clone().unwrap().data.unwrap().get_fit_score()
>= right.val.clone().unwrap().data.unwrap().get_fit_score()
{
left.val.clone().unwrap()
} else {
right.val.clone().unwrap()
};
new_val.process_node(match self.iteration_scaling {
IterationScaling::Linear(x) => x * height,
IterationScaling::Constant(x) => x,
})?;
Ok(btree!(Some(new_val), left, right))
} }
} }
/// Runs one step of simulation on the current bracket which includes: pub fn simulate(&mut self, steps: u64) -> Result<(), Error> {
/// 1) Creating a new branch of the same height and performing the same steps for each subtree. self.data.mutate(|b| b.increase_height(steps as usize))??;
/// 2) Simulating the top node of the current branch.
/// 3) Comparing the top node of the current branch to the top node of the new branch.
/// 4) Takes the best performing node and makes it the root of the tree.
///
/// # Examples
/// ```
/// # use gemla::bracket::*;
/// # use gemla::error::Error;
/// # use serde::{Deserialize, Serialize};
/// # use std::fmt;
/// # use std::str::FromStr;
/// # use std::string::ToString;
/// # use std::path;
/// #
/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq, Debug)]
/// # struct TestState {
/// # pub score: f64,
/// # }
/// #
/// # impl fmt::Display for TestState {
/// # fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
/// # write!(f, "{}", self.score)
/// # }
/// # }
/// #
/// # impl TestState {
/// # fn new(score: f64) -> TestState {
/// # TestState { score: score }
/// # }
/// # }
/// #
/// # impl genetic_node::GeneticNode for TestState {
/// # fn simulate(&mut self, iterations: u64) -> Result<(), Error> {
/// # self.score += iterations as f64;
/// # Ok(())
/// # }
/// #
/// # fn get_fit_score(&self) -> f64 {
/// # self.score
/// # }
/// #
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
/// # Ok(())
/// # }
/// #
/// # fn mutate(&mut self) -> Result<(), Error> {
/// # Ok(())
/// # }
/// #
/// # fn initialize() -> Result<Box<Self>, Error> {
/// # Ok(Box::new(TestState { score: 0.0 }))
/// # }
/// #
/// # fn merge(left: &TestState, right: &TestState) -> Result<Box<Self>, Error> {
/// # Ok(Box::new(left.clone()))
/// # }
/// # }
/// #
/// # fn main() {
/// let mut bracket = Bracket::<TestState>::initialize(path::PathBuf::from("./temp"))
/// .expect("Bracket failed to initialize");
///
/// // Running simulations 3 times
/// for _ in 0..3 {
/// bracket
/// .mutate(|b| drop(b.run_simulation_step()))
/// .expect("Failed to run step");
/// }
///
/// assert_eq!(bracket.readonly().tree.height(), 4);
///
/// # std::fs::remove_file("./temp").expect("Unable to remove file");
/// # }
/// ```
pub fn run_simulation_step(&mut self) -> Result<&mut Self, Error> {
let new_branch = self.create_new_branch(self.tree.height())?;
self.tree self.data.mutate(|b| b.process_tree())??;
.val
.clone()
.unwrap()
.process_node(match self.iteration_scaling {
IterationScaling::Linear(x) => (x * self.tree.height()),
IterationScaling::Constant(x) => x,
})?;
let new_val = if new_branch Ok(())
.val
.clone()
.unwrap()
.data
.unwrap()
.get_fit_score()
>= self.tree.val.clone().unwrap().data.unwrap().get_fit_score()
{
new_branch.val.clone()
} else {
self.tree.val.clone()
};
self.tree = btree!(new_val, new_branch, self.tree.clone());
Ok(self)
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::bracket::*; use crate::bracket::*;
use crate::tree::*;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::str::FromStr; use std::str::FromStr;
@ -446,10 +140,6 @@ mod tests {
Ok(()) Ok(())
} }
fn get_fit_score(&self) -> f64 {
self.score
}
fn calculate_scores_and_trim(&mut self) -> Result<(), Error> { fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
Ok(()) Ok(())
} }
@ -463,66 +153,11 @@ mod tests {
} }
fn merge(left: &TestState, right: &TestState) -> Result<Box<TestState>, Error> { fn merge(left: &TestState, right: &TestState) -> Result<Box<TestState>, Error> {
Ok(Box::new(if left.get_fit_score() > right.get_fit_score() { Ok(Box::new(if left.score > right.score {
left.clone() left.clone()
} else { } else {
right.clone() right.clone()
})) }))
} }
} }
#[test]
fn test_new() {
let bracket = Bracket::<TestState>::initialize(path::PathBuf::from("./temp"))
.expect("Bracket failed to initialize");
assert_eq!(
bracket,
file_linked::FileLinked::new(
Bracket {
tree: Tree {
val: Some(GeneticNodeWrapper::new().unwrap()),
left: None,
right: None
},
iteration_scaling: IterationScaling::Constant(1)
},
path::PathBuf::from("./temp")
)
.unwrap()
);
std::fs::remove_file("./temp").expect("Unable to remove file");
}
#[test]
fn test_run() {
let mut bracket = Bracket::<TestState>::initialize(path::PathBuf::from("./temp2"))
.expect("Bracket failed to initialize");
bracket
.mutate(|b| drop(b.iteration_scaling(IterationScaling::Linear(2))))
.expect("Failed to set iteration scaling");
for _ in 0..3 {
bracket
.mutate(|b| drop(b.run_simulation_step()))
.expect("Failed to run step");
}
assert_eq!(bracket.readonly().tree.height(), 4);
assert_eq!(
bracket
.readonly()
.tree
.val
.clone()
.unwrap()
.data
.unwrap()
.score,
15.0
);
std::fs::remove_file("./temp2").expect("Unable to remove file");
}
} }

View file

@ -5,6 +5,8 @@ pub enum Error {
#[error(transparent)] #[error(transparent)]
FileLinked(file_linked::Error), FileLinked(file_linked::Error),
#[error(transparent)] #[error(transparent)]
IO(std::io::Error),
#[error(transparent)]
Other(#[from] anyhow::Error), Other(#[from] anyhow::Error),
} }
@ -16,3 +18,9 @@ impl From<file_linked::Error> for Error {
} }
} }
} }
impl From<std::io::Error> for Error {
fn from(error: std::io::Error) -> Error {
Error::IO(error)
}
}