Refactoring bracket interface

This commit is contained in:
vandomej 2021-10-05 16:29:59 -07:00
parent 9081fb0b3c
commit 569a17f145
5 changed files with 137 additions and 599 deletions

View file

@ -2,12 +2,10 @@
extern crate serde;
use std::fs;
use std::io;
use std::fs::File;
use std::io::prelude::*;
use std::path;
use anyhow::{anyhow, Context};
use std::path::PathBuf;
use anyhow::Context;
use serde::de::DeserializeOwned;
use serde::Serialize;
use thiserror::Error;
@ -29,7 +27,7 @@ where
T: Serialize,
{
val: T,
path: path::PathBuf,
path: PathBuf,
}
impl<T> FileLinked<T>
@ -44,7 +42,7 @@ where
/// # use serde::{Deserialize, Serialize};
/// # use std::fmt;
/// # use std::string::ToString;
/// # use std::path;
/// # use std::path::PathBuf;
/// #
/// # #[derive(Deserialize, Serialize)]
/// # struct Test {
@ -60,7 +58,7 @@ where
/// c: 3.0
/// };
///
/// let linked_test = FileLinked::new(test, path::PathBuf::from("./temp"))
/// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"))
/// .expect("Unable to create file linked object");
///
/// assert_eq!(linked_test.readonly().a, 1);
@ -82,7 +80,7 @@ where
/// # use serde::{Deserialize, Serialize};
/// # use std::fmt;
/// # use std::string::ToString;
/// # use std::path;
/// # use std::path::PathBuf;
/// #
/// #[derive(Deserialize, Serialize)]
/// struct Test {
@ -98,7 +96,7 @@ where
/// c: 3.0
/// };
///
/// let linked_test = FileLinked::new(test, path::PathBuf::from("./temp"))
/// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"))
/// .expect("Unable to create file linked object");
///
/// assert_eq!(linked_test.readonly().a, 1);
@ -108,20 +106,14 @@ where
/// # std::fs::remove_file("./temp").expect("Unable to remove file");
/// # }
/// ```
pub fn new(val: T, path: path::PathBuf) -> Result<FileLinked<T>, Error> {
let result = FileLinked { val, path };
pub fn new(val: T, path: &PathBuf) -> Result<FileLinked<T>, Error> {
let result = FileLinked { val, path: path.clone() };
result.write_data()?;
Ok(result)
}
fn write_data(&self) -> Result<(), Error> {
let mut file = fs::OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.open(&self.path)
let mut file = File::create(&self.path)
.with_context(|| format!("Unable to open path {}", self.path.display()))?;
write!(
@ -143,7 +135,7 @@ where
/// # use serde::{Deserialize, Serialize};
/// # use std::fmt;
/// # use std::string::ToString;
/// # use std::path;
/// # use std::path::PathBuf;
/// #
/// # #[derive(Deserialize, Serialize)]
/// # struct Test {
@ -159,7 +151,7 @@ where
/// c: 0.0
/// };
///
/// let mut linked_test = FileLinked::new(test, path::PathBuf::from("./temp"))
/// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"))
/// .expect("Unable to create file linked object");
///
/// assert_eq!(linked_test.readonly().a, 1);
@ -189,7 +181,7 @@ where
/// # use serde::{Deserialize, Serialize};
/// # use std::fmt;
/// # use std::string::ToString;
/// # use std::path;
/// # use std::path::PathBuf;
/// #
/// # #[derive(Deserialize, Serialize)]
/// # struct Test {
@ -205,7 +197,7 @@ where
/// c: 0.0
/// };
///
/// let mut linked_test = FileLinked::new(test, path::PathBuf::from("./temp"))
/// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"))
/// .expect("Unable to create file linked object");
///
/// assert_eq!(linked_test.readonly().a, 1);
@ -245,7 +237,7 @@ where
/// # use std::fs;
/// # use std::fs::OpenOptions;
/// # use std::io::Write;
/// # use std::path;
/// # use std::path::PathBuf;
/// #
/// # #[derive(Deserialize, Serialize)]
/// # struct Test {
@ -261,7 +253,7 @@ where
/// c: 3.0
/// };
///
/// let path = path::PathBuf::from("./temp");
/// let path = PathBuf::from("./temp");
///
/// let mut file = OpenOptions::new()
/// .write(true)
@ -275,7 +267,7 @@ where
///
/// drop(file);
///
/// let mut linked_test = FileLinked::<Test>::from_file(path)
/// let mut linked_test = FileLinked::<Test>::from_file(&path)
/// .expect("Unable to create file linked object");
///
/// assert_eq!(linked_test.readonly().a, test.a);
@ -287,27 +279,14 @@ where
/// # Ok(())
/// # }
/// ```
pub fn from_file(path: path::PathBuf) -> Result<FileLinked<T>, Error> {
let metadata = path
.metadata()
.with_context(|| format!("Error obtaining metadata for {}", path.display()))?;
pub fn from_file(path: &PathBuf) -> Result<FileLinked<T>, Error> {
let file = File::open(path)
.with_context(|| format!("Unable to open file {}", path.display()))?;
if metadata.is_file() {
let file = fs::OpenOptions::new()
.read(true)
.open(&path)
.with_context(|| format!("Unable to open file {}", path.display()))?;
let val = serde_json::from_reader(file)
.with_context(|| String::from("Unable to parse value from file."))?;
let val = serde_json::from_reader(file)
.with_context(|| String::from("Unable to parse value from file."))?;
Ok(FileLinked { val, path })
} else {
return Err(Error::IO(io::Error::new(
io::ErrorKind::Other,
anyhow!("{} is not a file.", path.display()),
)));
}
Ok(FileLinked { val, path: path.clone() })
}
}
@ -319,7 +298,7 @@ mod tests {
#[test]
fn test_mutate() -> Result<(), Error> {
let list = vec![1, 2, 3, 4];
let mut file_linked_list = FileLinked::new(list, path::PathBuf::from("test.txt"))?;
let mut file_linked_list = FileLinked::new(list, &PathBuf::from("test.txt"))?;
assert_eq!(format!("{:?}", file_linked_list.readonly()), "[1, 2, 3, 4]");

View file

@ -40,14 +40,6 @@ impl GeneticNode for TestState {
Ok(())
}
fn get_fit_score(&self) -> f64 {
self.population
.clone()
.into_iter()
.reduce(f64::max)
.unwrap()
}
fn calculate_scores_and_trim(&mut self) -> Result<(), error::Error> {
let mut v = self.population.clone();
@ -144,16 +136,6 @@ mod tests {
.all(|(&a, &b)| b >= a - 30.0 && b <= a + 30.0))
}
#[test]
fn test_get_fit_score() {
let state = TestState {
thread_rng: thread_rng(),
population: vec![1.0, 1.0, 2.0, 3.0],
};
assert_eq!(state.get_fit_score(), 3.0);
}
#[test]
fn test_calculate_scores_and_trim() {
let mut state = TestState {

View file

@ -53,10 +53,6 @@ pub trait GeneticNode {
/// # Ok(())
/// # }
/// #
/// # fn get_fit_score(&self) -> f64 {
/// # self.fit_score
/// # }
/// #
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
/// # Ok(())
/// # }
@ -72,7 +68,7 @@ pub trait GeneticNode {
///
/// # fn main() -> Result<(), Error> {
/// let node = Node::initialize()?;
/// assert_eq!(node.get_fit_score(), 0.0);
/// assert_eq!(node.fit_score, 0.0);
/// # Ok(())
/// # }
/// ```
@ -81,7 +77,7 @@ pub trait GeneticNode {
/// Runs a simulation on the state object for the given number of `iterations` in order to guage it's fitness.
/// This will be called for every node in a bracket before evaluating it's fitness against other nodes.
///
/// #Examples
/// # Examples
///
/// ```
/// # use gemla::bracket::genetic_node::GeneticNode;
@ -104,7 +100,17 @@ pub trait GeneticNode {
/// # Ok(())
/// }
/// }
///
///
/// # impl Node {
/// # fn get_fit_score(&self) -> f64 {
/// # self.models
/// # .iter()
/// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap())
/// # .unwrap()
/// # .fit_score
/// # }
/// # }
/// #
/// impl GeneticNode for Node {
/// # fn initialize() -> Result<Box<Node>, Error> {
/// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]}))
@ -121,10 +127,6 @@ pub trait GeneticNode {
/// }
///
/// //...
///
/// # fn get_fit_score(&self) -> f64 {
/// # self.models.iter().max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap()).unwrap().fit_score
/// # }
/// #
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
/// # Ok(())
@ -148,76 +150,6 @@ pub trait GeneticNode {
/// ```
fn simulate(&mut self, iterations: u64) -> Result<(), Error>;
/// Returns a fit score associated with the nodes performance.
/// This will be used by a bracket in order to determine the most successful child.
///
/// # Examples
/// ```
/// # use gemla::bracket::genetic_node::GeneticNode;
/// # use gemla::error::Error;
/// #
/// struct Model {
/// pub fit_score: f64,
/// //...
/// }
///
/// struct Node {
/// pub models: Vec<Model>,
/// //...
/// }
///
/// # impl Model {
/// # fn fit(&mut self, epochs: u64) -> Result<(), Error> {
/// # //...
/// # self.fit_score += epochs as f64;
/// # Ok(())
/// # }
/// # }
///
/// impl GeneticNode for Node {
/// # fn initialize() -> Result<Box<Node>, Error> {
/// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]}))
/// # }
/// #
/// # //...
/// #
/// # fn simulate(&mut self, iterations: u64) -> Result<(), Error> {
/// # for m in self.models.iter_mut()
/// # {
/// # m.fit(iterations)?;
/// # }
/// # Ok(())
/// # }
/// #
/// //...
///
/// fn get_fit_score(&self) -> f64 {
/// self.models.iter().max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap()).unwrap().fit_score
/// }
///
/// //...
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
/// # Ok(())
/// # }
/// #
/// # fn mutate(&mut self) -> Result<(), Error> {
/// # Ok(())
/// # }
/// #
/// # fn merge(left: &Node, right: &Node) -> Result<Box<Node>, Error> {
/// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]}))
/// # }
/// }
///
/// # fn main() -> Result<(), Error> {
/// let mut node = Node::initialize()?;
/// node.simulate(5)?;
/// assert_eq!(node.get_fit_score(), 5.0);
/// # Ok(())
/// # }
/// ```
fn get_fit_score(&self) -> f64;
/// Used when scoring the nodes after simulating and should remove underperforming children.
///
/// # Examples
@ -239,11 +171,23 @@ pub trait GeneticNode {
/// # impl Model {
/// # fn fit(&mut self, epochs: u64) -> Result<(), Error> {
/// # //...
/// # self.fit_score += epochs as f64;
/// # self.fit_score += epochs as f64;
/// # Ok(())
/// # }
/// # }
///
/// #
/// #
/// # impl Node {
/// # fn get_fit_score(&self) -> f64 {
/// # self.models
/// # .iter()
/// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap())
/// # .unwrap()
/// # .fit_score
/// # }
/// # }
/// #
///
/// impl GeneticNode for Node {
/// # fn initialize() -> Result<Box<Node>, Error> {
/// # Ok(Box::new(Node {
@ -269,14 +213,6 @@ pub trait GeneticNode {
/// #
/// //...
///
/// # fn get_fit_score(&self) -> f64 {
/// # self.models
/// # .iter()
/// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap())
/// # .unwrap()
/// # .fit_score
/// # }
/// #
/// fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
/// self.models.sort_by(|a, b| a.fit_score.partial_cmp(&b.fit_score).unwrap().reverse());
/// self.models.truncate(3);
@ -326,6 +262,16 @@ pub trait GeneticNode {
/// population_size: i64,
/// //...
/// }
/// #
/// # impl Node {
/// # fn get_fit_score(&self) -> f64 {
/// # self.models
/// # .iter()
/// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap())
/// # .unwrap()
/// # .fit_score
/// # }
/// # }
///
/// # impl Model {
/// # fn fit(&mut self, epochs: u64) -> Result<(), Error> {
@ -364,14 +310,6 @@ pub trait GeneticNode {
/// # Ok(())
/// # }
/// #
/// # fn get_fit_score(&self) -> f64 {
/// # self.models
/// # .iter()
/// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap())
/// # .unwrap()
/// # .fit_score
/// # }
/// #
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
/// # self.models.sort_by(|a, b| a.fit_score.partial_cmp(&b.fit_score).unwrap().reverse());
/// # self.models.truncate(3);
@ -458,10 +396,6 @@ where
/// # Ok(())
/// # }
/// #
/// # fn get_fit_score(&self) -> f64 {
/// # self.fit_score
/// # }
/// #
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
/// # Ok(())
/// # }
@ -477,7 +411,7 @@ where
///
/// # fn main() -> Result<(), Error> {
/// let mut wrapped_node = GeneticNodeWrapper::<Node>::new()?;
/// assert_eq!(wrapped_node.data.unwrap().get_fit_score(), 0.0);
/// assert_eq!(wrapped_node.data.unwrap().fit_score, 0.0);
/// # Ok(())
/// # }
/// ```

View file

@ -5,78 +5,21 @@ pub mod genetic_node;
use crate::error::Error;
use crate::tree;
use genetic_node::GeneticNodeWrapper;
use genetic_node::{GeneticNodeWrapper, GeneticNode};
use file_linked::FileLinked;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
use std::path;
use std::path::PathBuf;
use std::fs::File;
use std::io::ErrorKind;
/// As the bracket tree increases in height, `IterationScaling` can be used to configure the number of iterations that
/// a node runs for.
///
/// # Examples
///
/// ```
/// # use gemla::bracket::*;
/// # use gemla::error::Error;
/// # use serde::{Deserialize, Serialize};
/// # use std::fmt;
/// # use std::str::FromStr;
/// # use std::string::ToString;
/// # use std::path;
/// #
/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq, Debug)]
/// # struct TestState {
/// # pub score: f64,
/// # }
/// #
/// # impl TestState {
/// # fn new(score: f64) -> TestState {
/// # TestState { score: score }
/// # }
/// # }
/// #
/// # impl genetic_node::GeneticNode for TestState {
/// # fn simulate(&mut self, iterations: u64) -> Result<(), Error> {
/// # self.score += iterations as f64;
/// # Ok(())
/// # }
/// #
/// # fn get_fit_score(&self) -> f64 {
/// # self.score
/// # }
/// #
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
/// # Ok(())
/// # }
/// #
/// # fn mutate(&mut self) -> Result<(), Error> {
/// # Ok(())
/// # }
/// #
/// # fn initialize() -> Result<Box<Self>, Error> {
/// # Ok(Box::new(TestState { score: 0.0 }))
/// # }
/// #
/// # fn merge(left: &TestState, right: &TestState) -> Result<Box<Self>, Error> {
/// # Ok(Box::new(left.clone()))
/// # }
/// # }
/// #
/// # fn main() {
/// let mut bracket = Bracket::<TestState>::initialize(path::PathBuf::from("./temp"))
/// .expect("Bracket failed to initialize");
///
/// // Constant iteration scaling ensures that every node is simulated 5 times.
/// bracket
/// .mutate(|b| drop(b.iteration_scaling(IterationScaling::Constant(5))))
/// .expect("Failed to set iteration scaling");
///
/// # std::fs::remove_file("./temp").expect("Unable to remove file");
/// # }
/// ```
/// TODO
#[derive(Clone, Serialize, Deserialize, Copy, Debug, PartialEq)]
#[serde(tag = "enumType", content = "enumContent")]
pub enum IterationScaling {
@ -93,336 +36,87 @@ impl Default for IterationScaling {
}
}
/// Creates a tournament style bracket for simulating and evaluating nodes of type `T` implementing [`GeneticNode`].
/// These nodes are built upwards as a balanced binary tree starting from the bottom. This results in `Bracket` building
/// a separate tree of the same height then merging trees together. Evaluating populations between nodes and taking the strongest
/// individuals.
///
/// [`GeneticNode`]: genetic_node::GeneticNode
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct Bracket<T>
struct Bracket<T>
where
T: genetic_node::GeneticNode + Serialize,
T: GeneticNode + Serialize,
{
pub tree: tree::Tree<Option<GeneticNodeWrapper<T>>>,
iteration_scaling: IterationScaling,
}
impl<T> Bracket<T>
where
T: genetic_node::GeneticNode
+ Default
+ DeserializeOwned
+ Serialize
+ Clone
+ PartialEq
+ Debug,
where T: GeneticNode + Serialize
{
/// Initializes a bracket of type `T` storing the contents to `file_path`
///
/// # Examples
/// ```
/// # use gemla::bracket::*;
/// # use gemla::btree;
/// # use gemla::tree;
/// # use gemla::error::Error;
/// # use serde::{Deserialize, Serialize};
/// # use std::fmt;
/// # use std::str::FromStr;
/// # use std::string::ToString;
/// # use std::path;
/// #
/// #[derive(Default, Deserialize, Serialize, Debug, Clone, PartialEq)]
/// struct TestState {
/// pub score: f64,
/// }
///
/// # impl FromStr for TestState {
/// # type Err = String;
/// #
/// # fn from_str(s: &str) -> Result<TestState, Self::Err> {
/// # serde_json::from_str(s).map_err(|_| format!("Unable to parse string {}", s))
/// # }
/// # }
/// #
/// # impl fmt::Display for TestState {
/// # fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
/// # write!(f, "{}", self.score)
/// # }
/// # }
/// #
/// impl TestState {
/// fn new(score: f64) -> TestState {
/// TestState { score: score }
/// }
/// }
///
/// impl genetic_node::GeneticNode for TestState {
/// # fn simulate(&mut self, iterations: u64) -> Result<(), Error> {
/// # self.score += iterations as f64;
/// # Ok(())
/// # }
/// #
/// # fn get_fit_score(&self) -> f64 {
/// # self.score
/// # }
/// #
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
/// # Ok(())
/// # }
/// #
/// # fn mutate(&mut self) -> Result<(), Error> {
/// # Ok(())
/// # }
/// #
/// fn initialize() -> Result<Box<Self>, Error> {
/// Ok(Box::new(TestState { score: 0.0 }))
/// }
///
/// //...
/// #
/// # fn merge(left: &TestState, right: &TestState) -> Result<Box<Self>, Error> {
/// # Ok(Box::new(left.clone()))
/// # }
/// }
///
/// # fn main() {
/// let mut bracket = Bracket::<TestState>::initialize(path::PathBuf::from("./temp"))
/// .expect("Bracket failed to initialize");
///
/// std::fs::remove_file("./temp").expect("Unable to remove file");
/// # }
/// ```
pub fn initialize(file_path: path::PathBuf) -> Result<FileLinked<Self>, Error> {
Ok(FileLinked::new(
Bracket {
tree: btree!(Some(GeneticNodeWrapper::new()?)),
iteration_scaling: IterationScaling::default(),
fn increase_height(&mut self, _amount: usize) -> Result<(), Error> {
Ok(())
}
fn process_tree(&mut self) -> Result<(), Error> {
Ok(())
}
}
/// Creates a tournament style bracket for simulating and evaluating nodes of type `T` implementing [`GeneticNode`].
/// These nodes are built upwards as a balanced binary tree starting from the bottom. This results in `Bracket` building
/// a separate tree of the same height then merging trees together. Evaluating populations between nodes and taking the strongest
/// individuals.
///
/// [`GeneticNode`]: genetic_node::GeneticNode
pub struct Gemla<T>
where T: GeneticNode + Serialize + DeserializeOwned
{
data: FileLinked<Bracket<T>>
}
impl<T> Gemla<T>
where
T: GeneticNode
+ Serialize
+ DeserializeOwned
+ Default
{
pub fn new(path: &PathBuf, overwrite: bool) -> Result<Self, Error> {
match File::open(path) {
Ok(file) => {
drop(file);
Ok(Gemla {
data:
if overwrite {
FileLinked::from_file(path)?
} else {
FileLinked::new(Bracket {
tree: btree!(None),
iteration_scaling: IterationScaling::default()
}, path)?
}
})
},
file_path,
)?)
}
/// Given a bracket object, configures it's [`IterationScaling`].
///
/// # Examples
/// ```
/// # use gemla::bracket::*;
/// # use gemla::error::Error;
/// # use serde::{Deserialize, Serialize};
/// # use std::fmt;
/// # use std::str::FromStr;
/// # use std::string::ToString;
/// # use std::path;
/// #
/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq, Debug)]
/// # struct TestState {
/// # pub score: f64,
/// # }
/// #
/// # impl fmt::Display for TestState {
/// # fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
/// # write!(f, "{}", self.score)
/// # }
/// # }
/// #
/// # impl TestState {
/// # fn new(score: f64) -> TestState {
/// # TestState { score: score }
/// # }
/// # }
/// #
/// # impl genetic_node::GeneticNode for TestState {
/// # fn simulate(&mut self, iterations: u64) -> Result<(), Error> {
/// # self.score += iterations as f64;
/// # Ok(())
/// # }
/// #
/// # fn get_fit_score(&self) -> f64 {
/// # self.score
/// # }
/// #
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
/// # Ok(())
/// # }
/// #
/// # fn mutate(&mut self) -> Result<(), Error> {
/// # Ok(())
/// # }
/// #
/// # fn initialize() -> Result<Box<Self>, Error> {
/// # Ok(Box::new(TestState { score: 0.0 }))
/// # }
/// #
/// # fn merge(left: &TestState, right: &TestState) -> Result<Box<Self>, Error> {
/// # Ok(Box::new(left.clone()))
/// # }
/// # }
/// #
/// # fn main() {
/// let mut bracket = Bracket::<TestState>::initialize(path::PathBuf::from("./temp"))
/// .expect("Bracket failed to initialize");
///
/// // Constant iteration scaling ensures that every node is simulated 5 times.
/// bracket
/// .mutate(|b| drop(b.iteration_scaling(IterationScaling::Constant(5))))
/// .expect("Failed to set iteration scaling");
///
/// # std::fs::remove_file("./temp").expect("Unable to remove file");
/// # }
/// ```
pub fn iteration_scaling(&mut self, iteration_scaling: IterationScaling) -> &mut Self {
self.iteration_scaling = iteration_scaling;
self
}
// Creates a balanced tree with the given `height` that will be used as a branch of the primary tree.
// This additionally simulates and evaluates nodes in the branch as it is built.
fn create_new_branch(
&self,
height: u64,
) -> Result<tree::Tree<Option<GeneticNodeWrapper<T>>>, Error> {
if height == 1 {
let mut base_node = GeneticNodeWrapper::new()?;
base_node.process_node(match self.iteration_scaling {
IterationScaling::Linear(x) => x * height,
IterationScaling::Constant(x) => x,
})?;
Ok(btree!(Some(base_node)))
} else {
let left = self.create_new_branch(height - 1)?;
let right = self.create_new_branch(height - 1)?;
let mut new_val = if left.val.clone().unwrap().data.unwrap().get_fit_score()
>= right.val.clone().unwrap().data.unwrap().get_fit_score()
{
left.val.clone().unwrap()
} else {
right.val.clone().unwrap()
};
new_val.process_node(match self.iteration_scaling {
IterationScaling::Linear(x) => x * height,
IterationScaling::Constant(x) => x,
})?;
Ok(btree!(Some(new_val), left, right))
Err(error) if error.kind() == ErrorKind::NotFound => {
Ok(Gemla {
data: FileLinked::new(Bracket {
tree: btree!(None),
iteration_scaling: IterationScaling::default()
}, path)?
})
},
Err(error) => Err(Error::IO(error))
}
}
/// Runs one step of simulation on the current bracket which includes:
/// 1) Creating a new branch of the same height and performing the same steps for each subtree.
/// 2) Simulating the top node of the current branch.
/// 3) Comparing the top node of the current branch to the top node of the new branch.
/// 4) Takes the best performing node and makes it the root of the tree.
///
/// # Examples
/// ```
/// # use gemla::bracket::*;
/// # use gemla::error::Error;
/// # use serde::{Deserialize, Serialize};
/// # use std::fmt;
/// # use std::str::FromStr;
/// # use std::string::ToString;
/// # use std::path;
/// #
/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq, Debug)]
/// # struct TestState {
/// # pub score: f64,
/// # }
/// #
/// # impl fmt::Display for TestState {
/// # fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
/// # write!(f, "{}", self.score)
/// # }
/// # }
/// #
/// # impl TestState {
/// # fn new(score: f64) -> TestState {
/// # TestState { score: score }
/// # }
/// # }
/// #
/// # impl genetic_node::GeneticNode for TestState {
/// # fn simulate(&mut self, iterations: u64) -> Result<(), Error> {
/// # self.score += iterations as f64;
/// # Ok(())
/// # }
/// #
/// # fn get_fit_score(&self) -> f64 {
/// # self.score
/// # }
/// #
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
/// # Ok(())
/// # }
/// #
/// # fn mutate(&mut self) -> Result<(), Error> {
/// # Ok(())
/// # }
/// #
/// # fn initialize() -> Result<Box<Self>, Error> {
/// # Ok(Box::new(TestState { score: 0.0 }))
/// # }
/// #
/// # fn merge(left: &TestState, right: &TestState) -> Result<Box<Self>, Error> {
/// # Ok(Box::new(left.clone()))
/// # }
/// # }
/// #
/// # fn main() {
/// let mut bracket = Bracket::<TestState>::initialize(path::PathBuf::from("./temp"))
/// .expect("Bracket failed to initialize");
///
/// // Running simulations 3 times
/// for _ in 0..3 {
/// bracket
/// .mutate(|b| drop(b.run_simulation_step()))
/// .expect("Failed to run step");
/// }
///
/// assert_eq!(bracket.readonly().tree.height(), 4);
///
/// # std::fs::remove_file("./temp").expect("Unable to remove file");
/// # }
/// ```
pub fn run_simulation_step(&mut self) -> Result<&mut Self, Error> {
let new_branch = self.create_new_branch(self.tree.height())?;
pub fn simulate(&mut self, steps: u64) -> Result<(), Error> {
self.data.mutate(|b| b.increase_height(steps as usize))??;
self.tree
.val
.clone()
.unwrap()
.process_node(match self.iteration_scaling {
IterationScaling::Linear(x) => (x * self.tree.height()),
IterationScaling::Constant(x) => x,
})?;
self.data.mutate(|b| b.process_tree())??;
let new_val = if new_branch
.val
.clone()
.unwrap()
.data
.unwrap()
.get_fit_score()
>= self.tree.val.clone().unwrap().data.unwrap().get_fit_score()
{
new_branch.val.clone()
} else {
self.tree.val.clone()
};
self.tree = btree!(new_val, new_branch, self.tree.clone());
Ok(self)
Ok(())
}
}
#[cfg(test)]
mod tests {
use crate::bracket::*;
use crate::tree::*;
use serde::{Deserialize, Serialize};
use std::str::FromStr;
@ -446,10 +140,6 @@ mod tests {
Ok(())
}
fn get_fit_score(&self) -> f64 {
self.score
}
fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
Ok(())
}
@ -463,66 +153,11 @@ mod tests {
}
fn merge(left: &TestState, right: &TestState) -> Result<Box<TestState>, Error> {
Ok(Box::new(if left.get_fit_score() > right.get_fit_score() {
Ok(Box::new(if left.score > right.score {
left.clone()
} else {
right.clone()
}))
}
}
#[test]
fn test_new() {
let bracket = Bracket::<TestState>::initialize(path::PathBuf::from("./temp"))
.expect("Bracket failed to initialize");
assert_eq!(
bracket,
file_linked::FileLinked::new(
Bracket {
tree: Tree {
val: Some(GeneticNodeWrapper::new().unwrap()),
left: None,
right: None
},
iteration_scaling: IterationScaling::Constant(1)
},
path::PathBuf::from("./temp")
)
.unwrap()
);
std::fs::remove_file("./temp").expect("Unable to remove file");
}
#[test]
fn test_run() {
let mut bracket = Bracket::<TestState>::initialize(path::PathBuf::from("./temp2"))
.expect("Bracket failed to initialize");
bracket
.mutate(|b| drop(b.iteration_scaling(IterationScaling::Linear(2))))
.expect("Failed to set iteration scaling");
for _ in 0..3 {
bracket
.mutate(|b| drop(b.run_simulation_step()))
.expect("Failed to run step");
}
assert_eq!(bracket.readonly().tree.height(), 4);
assert_eq!(
bracket
.readonly()
.tree
.val
.clone()
.unwrap()
.data
.unwrap()
.score,
15.0
);
std::fs::remove_file("./temp2").expect("Unable to remove file");
}
}

View file

@ -5,6 +5,8 @@ pub enum Error {
#[error(transparent)]
FileLinked(file_linked::Error),
#[error(transparent)]
IO(std::io::Error),
#[error(transparent)]
Other(#[from] anyhow::Error),
}
@ -16,3 +18,9 @@ impl From<file_linked::Error> for Error {
}
}
}
impl From<std::io::Error> for Error {
fn from(error: std::io::Error) -> Error {
Error::IO(error)
}
}