Refactoring bracket interface
This commit is contained in:
parent
9081fb0b3c
commit
569a17f145
5 changed files with 137 additions and 599 deletions
|
@ -2,12 +2,10 @@
|
|||
|
||||
extern crate serde;
|
||||
|
||||
use std::fs;
|
||||
use std::io;
|
||||
use std::fs::File;
|
||||
use std::io::prelude::*;
|
||||
use std::path;
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use std::path::PathBuf;
|
||||
use anyhow::Context;
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::Serialize;
|
||||
use thiserror::Error;
|
||||
|
@ -29,7 +27,7 @@ where
|
|||
T: Serialize,
|
||||
{
|
||||
val: T,
|
||||
path: path::PathBuf,
|
||||
path: PathBuf,
|
||||
}
|
||||
|
||||
impl<T> FileLinked<T>
|
||||
|
@ -44,7 +42,7 @@ where
|
|||
/// # use serde::{Deserialize, Serialize};
|
||||
/// # use std::fmt;
|
||||
/// # use std::string::ToString;
|
||||
/// # use std::path;
|
||||
/// # use std::path::PathBuf;
|
||||
/// #
|
||||
/// # #[derive(Deserialize, Serialize)]
|
||||
/// # struct Test {
|
||||
|
@ -60,7 +58,7 @@ where
|
|||
/// c: 3.0
|
||||
/// };
|
||||
///
|
||||
/// let linked_test = FileLinked::new(test, path::PathBuf::from("./temp"))
|
||||
/// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"))
|
||||
/// .expect("Unable to create file linked object");
|
||||
///
|
||||
/// assert_eq!(linked_test.readonly().a, 1);
|
||||
|
@ -82,7 +80,7 @@ where
|
|||
/// # use serde::{Deserialize, Serialize};
|
||||
/// # use std::fmt;
|
||||
/// # use std::string::ToString;
|
||||
/// # use std::path;
|
||||
/// # use std::path::PathBuf;
|
||||
/// #
|
||||
/// #[derive(Deserialize, Serialize)]
|
||||
/// struct Test {
|
||||
|
@ -98,7 +96,7 @@ where
|
|||
/// c: 3.0
|
||||
/// };
|
||||
///
|
||||
/// let linked_test = FileLinked::new(test, path::PathBuf::from("./temp"))
|
||||
/// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"))
|
||||
/// .expect("Unable to create file linked object");
|
||||
///
|
||||
/// assert_eq!(linked_test.readonly().a, 1);
|
||||
|
@ -108,20 +106,14 @@ where
|
|||
/// # std::fs::remove_file("./temp").expect("Unable to remove file");
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn new(val: T, path: path::PathBuf) -> Result<FileLinked<T>, Error> {
|
||||
let result = FileLinked { val, path };
|
||||
|
||||
pub fn new(val: T, path: &PathBuf) -> Result<FileLinked<T>, Error> {
|
||||
let result = FileLinked { val, path: path.clone() };
|
||||
result.write_data()?;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn write_data(&self) -> Result<(), Error> {
|
||||
let mut file = fs::OpenOptions::new()
|
||||
.write(true)
|
||||
.create(true)
|
||||
.truncate(true)
|
||||
.open(&self.path)
|
||||
let mut file = File::create(&self.path)
|
||||
.with_context(|| format!("Unable to open path {}", self.path.display()))?;
|
||||
|
||||
write!(
|
||||
|
@ -143,7 +135,7 @@ where
|
|||
/// # use serde::{Deserialize, Serialize};
|
||||
/// # use std::fmt;
|
||||
/// # use std::string::ToString;
|
||||
/// # use std::path;
|
||||
/// # use std::path::PathBuf;
|
||||
/// #
|
||||
/// # #[derive(Deserialize, Serialize)]
|
||||
/// # struct Test {
|
||||
|
@ -159,7 +151,7 @@ where
|
|||
/// c: 0.0
|
||||
/// };
|
||||
///
|
||||
/// let mut linked_test = FileLinked::new(test, path::PathBuf::from("./temp"))
|
||||
/// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"))
|
||||
/// .expect("Unable to create file linked object");
|
||||
///
|
||||
/// assert_eq!(linked_test.readonly().a, 1);
|
||||
|
@ -189,7 +181,7 @@ where
|
|||
/// # use serde::{Deserialize, Serialize};
|
||||
/// # use std::fmt;
|
||||
/// # use std::string::ToString;
|
||||
/// # use std::path;
|
||||
/// # use std::path::PathBuf;
|
||||
/// #
|
||||
/// # #[derive(Deserialize, Serialize)]
|
||||
/// # struct Test {
|
||||
|
@ -205,7 +197,7 @@ where
|
|||
/// c: 0.0
|
||||
/// };
|
||||
///
|
||||
/// let mut linked_test = FileLinked::new(test, path::PathBuf::from("./temp"))
|
||||
/// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"))
|
||||
/// .expect("Unable to create file linked object");
|
||||
///
|
||||
/// assert_eq!(linked_test.readonly().a, 1);
|
||||
|
@ -245,7 +237,7 @@ where
|
|||
/// # use std::fs;
|
||||
/// # use std::fs::OpenOptions;
|
||||
/// # use std::io::Write;
|
||||
/// # use std::path;
|
||||
/// # use std::path::PathBuf;
|
||||
/// #
|
||||
/// # #[derive(Deserialize, Serialize)]
|
||||
/// # struct Test {
|
||||
|
@ -261,7 +253,7 @@ where
|
|||
/// c: 3.0
|
||||
/// };
|
||||
///
|
||||
/// let path = path::PathBuf::from("./temp");
|
||||
/// let path = PathBuf::from("./temp");
|
||||
///
|
||||
/// let mut file = OpenOptions::new()
|
||||
/// .write(true)
|
||||
|
@ -275,7 +267,7 @@ where
|
|||
///
|
||||
/// drop(file);
|
||||
///
|
||||
/// let mut linked_test = FileLinked::<Test>::from_file(path)
|
||||
/// let mut linked_test = FileLinked::<Test>::from_file(&path)
|
||||
/// .expect("Unable to create file linked object");
|
||||
///
|
||||
/// assert_eq!(linked_test.readonly().a, test.a);
|
||||
|
@ -287,27 +279,14 @@ where
|
|||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn from_file(path: path::PathBuf) -> Result<FileLinked<T>, Error> {
|
||||
let metadata = path
|
||||
.metadata()
|
||||
.with_context(|| format!("Error obtaining metadata for {}", path.display()))?;
|
||||
pub fn from_file(path: &PathBuf) -> Result<FileLinked<T>, Error> {
|
||||
let file = File::open(path)
|
||||
.with_context(|| format!("Unable to open file {}", path.display()))?;
|
||||
|
||||
if metadata.is_file() {
|
||||
let file = fs::OpenOptions::new()
|
||||
.read(true)
|
||||
.open(&path)
|
||||
.with_context(|| format!("Unable to open file {}", path.display()))?;
|
||||
let val = serde_json::from_reader(file)
|
||||
.with_context(|| String::from("Unable to parse value from file."))?;
|
||||
|
||||
let val = serde_json::from_reader(file)
|
||||
.with_context(|| String::from("Unable to parse value from file."))?;
|
||||
|
||||
Ok(FileLinked { val, path })
|
||||
} else {
|
||||
return Err(Error::IO(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
anyhow!("{} is not a file.", path.display()),
|
||||
)));
|
||||
}
|
||||
Ok(FileLinked { val, path: path.clone() })
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -319,7 +298,7 @@ mod tests {
|
|||
#[test]
|
||||
fn test_mutate() -> Result<(), Error> {
|
||||
let list = vec![1, 2, 3, 4];
|
||||
let mut file_linked_list = FileLinked::new(list, path::PathBuf::from("test.txt"))?;
|
||||
let mut file_linked_list = FileLinked::new(list, &PathBuf::from("test.txt"))?;
|
||||
|
||||
assert_eq!(format!("{:?}", file_linked_list.readonly()), "[1, 2, 3, 4]");
|
||||
|
||||
|
|
|
@ -40,14 +40,6 @@ impl GeneticNode for TestState {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn get_fit_score(&self) -> f64 {
|
||||
self.population
|
||||
.clone()
|
||||
.into_iter()
|
||||
.reduce(f64::max)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn calculate_scores_and_trim(&mut self) -> Result<(), error::Error> {
|
||||
let mut v = self.population.clone();
|
||||
|
||||
|
@ -144,16 +136,6 @@ mod tests {
|
|||
.all(|(&a, &b)| b >= a - 30.0 && b <= a + 30.0))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_fit_score() {
|
||||
let state = TestState {
|
||||
thread_rng: thread_rng(),
|
||||
population: vec![1.0, 1.0, 2.0, 3.0],
|
||||
};
|
||||
|
||||
assert_eq!(state.get_fit_score(), 3.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_calculate_scores_and_trim() {
|
||||
let mut state = TestState {
|
||||
|
|
|
@ -53,10 +53,6 @@ pub trait GeneticNode {
|
|||
/// # Ok(())
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn get_fit_score(&self) -> f64 {
|
||||
/// # self.fit_score
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
|
@ -72,7 +68,7 @@ pub trait GeneticNode {
|
|||
///
|
||||
/// # fn main() -> Result<(), Error> {
|
||||
/// let node = Node::initialize()?;
|
||||
/// assert_eq!(node.get_fit_score(), 0.0);
|
||||
/// assert_eq!(node.fit_score, 0.0);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
|
@ -81,7 +77,7 @@ pub trait GeneticNode {
|
|||
/// Runs a simulation on the state object for the given number of `iterations` in order to guage it's fitness.
|
||||
/// This will be called for every node in a bracket before evaluating it's fitness against other nodes.
|
||||
///
|
||||
/// #Examples
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// # use gemla::bracket::genetic_node::GeneticNode;
|
||||
|
@ -104,7 +100,17 @@ pub trait GeneticNode {
|
|||
/// # Ok(())
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
///
|
||||
/// # impl Node {
|
||||
/// # fn get_fit_score(&self) -> f64 {
|
||||
/// # self.models
|
||||
/// # .iter()
|
||||
/// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap())
|
||||
/// # .unwrap()
|
||||
/// # .fit_score
|
||||
/// # }
|
||||
/// # }
|
||||
/// #
|
||||
/// impl GeneticNode for Node {
|
||||
/// # fn initialize() -> Result<Box<Node>, Error> {
|
||||
/// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]}))
|
||||
|
@ -121,10 +127,6 @@ pub trait GeneticNode {
|
|||
/// }
|
||||
///
|
||||
/// //...
|
||||
///
|
||||
/// # fn get_fit_score(&self) -> f64 {
|
||||
/// # self.models.iter().max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap()).unwrap().fit_score
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
|
||||
/// # Ok(())
|
||||
|
@ -148,76 +150,6 @@ pub trait GeneticNode {
|
|||
/// ```
|
||||
fn simulate(&mut self, iterations: u64) -> Result<(), Error>;
|
||||
|
||||
/// Returns a fit score associated with the nodes performance.
|
||||
/// This will be used by a bracket in order to determine the most successful child.
|
||||
///
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// # use gemla::bracket::genetic_node::GeneticNode;
|
||||
/// # use gemla::error::Error;
|
||||
/// #
|
||||
/// struct Model {
|
||||
/// pub fit_score: f64,
|
||||
/// //...
|
||||
/// }
|
||||
///
|
||||
/// struct Node {
|
||||
/// pub models: Vec<Model>,
|
||||
/// //...
|
||||
/// }
|
||||
///
|
||||
/// # impl Model {
|
||||
/// # fn fit(&mut self, epochs: u64) -> Result<(), Error> {
|
||||
/// # //...
|
||||
/// # self.fit_score += epochs as f64;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// # }
|
||||
///
|
||||
/// impl GeneticNode for Node {
|
||||
/// # fn initialize() -> Result<Box<Node>, Error> {
|
||||
/// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]}))
|
||||
/// # }
|
||||
/// #
|
||||
/// # //...
|
||||
/// #
|
||||
/// # fn simulate(&mut self, iterations: u64) -> Result<(), Error> {
|
||||
/// # for m in self.models.iter_mut()
|
||||
/// # {
|
||||
/// # m.fit(iterations)?;
|
||||
/// # }
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// #
|
||||
/// //...
|
||||
///
|
||||
/// fn get_fit_score(&self) -> f64 {
|
||||
/// self.models.iter().max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap()).unwrap().fit_score
|
||||
/// }
|
||||
///
|
||||
/// //...
|
||||
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn mutate(&mut self) -> Result<(), Error> {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn merge(left: &Node, right: &Node) -> Result<Box<Node>, Error> {
|
||||
/// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]}))
|
||||
/// # }
|
||||
/// }
|
||||
///
|
||||
/// # fn main() -> Result<(), Error> {
|
||||
/// let mut node = Node::initialize()?;
|
||||
/// node.simulate(5)?;
|
||||
/// assert_eq!(node.get_fit_score(), 5.0);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
fn get_fit_score(&self) -> f64;
|
||||
|
||||
/// Used when scoring the nodes after simulating and should remove underperforming children.
|
||||
///
|
||||
/// # Examples
|
||||
|
@ -239,11 +171,23 @@ pub trait GeneticNode {
|
|||
/// # impl Model {
|
||||
/// # fn fit(&mut self, epochs: u64) -> Result<(), Error> {
|
||||
/// # //...
|
||||
/// # self.fit_score += epochs as f64;
|
||||
/// # self.fit_score += epochs as f64;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// # }
|
||||
///
|
||||
/// #
|
||||
/// #
|
||||
/// # impl Node {
|
||||
/// # fn get_fit_score(&self) -> f64 {
|
||||
/// # self.models
|
||||
/// # .iter()
|
||||
/// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap())
|
||||
/// # .unwrap()
|
||||
/// # .fit_score
|
||||
/// # }
|
||||
/// # }
|
||||
/// #
|
||||
///
|
||||
/// impl GeneticNode for Node {
|
||||
/// # fn initialize() -> Result<Box<Node>, Error> {
|
||||
/// # Ok(Box::new(Node {
|
||||
|
@ -269,14 +213,6 @@ pub trait GeneticNode {
|
|||
/// #
|
||||
/// //...
|
||||
///
|
||||
/// # fn get_fit_score(&self) -> f64 {
|
||||
/// # self.models
|
||||
/// # .iter()
|
||||
/// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap())
|
||||
/// # .unwrap()
|
||||
/// # .fit_score
|
||||
/// # }
|
||||
/// #
|
||||
/// fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
|
||||
/// self.models.sort_by(|a, b| a.fit_score.partial_cmp(&b.fit_score).unwrap().reverse());
|
||||
/// self.models.truncate(3);
|
||||
|
@ -326,6 +262,16 @@ pub trait GeneticNode {
|
|||
/// population_size: i64,
|
||||
/// //...
|
||||
/// }
|
||||
/// #
|
||||
/// # impl Node {
|
||||
/// # fn get_fit_score(&self) -> f64 {
|
||||
/// # self.models
|
||||
/// # .iter()
|
||||
/// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap())
|
||||
/// # .unwrap()
|
||||
/// # .fit_score
|
||||
/// # }
|
||||
/// # }
|
||||
///
|
||||
/// # impl Model {
|
||||
/// # fn fit(&mut self, epochs: u64) -> Result<(), Error> {
|
||||
|
@ -364,14 +310,6 @@ pub trait GeneticNode {
|
|||
/// # Ok(())
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn get_fit_score(&self) -> f64 {
|
||||
/// # self.models
|
||||
/// # .iter()
|
||||
/// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap())
|
||||
/// # .unwrap()
|
||||
/// # .fit_score
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
|
||||
/// # self.models.sort_by(|a, b| a.fit_score.partial_cmp(&b.fit_score).unwrap().reverse());
|
||||
/// # self.models.truncate(3);
|
||||
|
@ -458,10 +396,6 @@ where
|
|||
/// # Ok(())
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn get_fit_score(&self) -> f64 {
|
||||
/// # self.fit_score
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
|
@ -477,7 +411,7 @@ where
|
|||
///
|
||||
/// # fn main() -> Result<(), Error> {
|
||||
/// let mut wrapped_node = GeneticNodeWrapper::<Node>::new()?;
|
||||
/// assert_eq!(wrapped_node.data.unwrap().get_fit_score(), 0.0);
|
||||
/// assert_eq!(wrapped_node.data.unwrap().fit_score, 0.0);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
|
|
|
@ -5,78 +5,21 @@ pub mod genetic_node;
|
|||
|
||||
use crate::error::Error;
|
||||
use crate::tree;
|
||||
use genetic_node::GeneticNodeWrapper;
|
||||
|
||||
use genetic_node::{GeneticNodeWrapper, GeneticNode};
|
||||
use file_linked::FileLinked;
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::Debug;
|
||||
use std::path;
|
||||
use std::path::PathBuf;
|
||||
use std::fs::File;
|
||||
use std::io::ErrorKind;
|
||||
|
||||
/// As the bracket tree increases in height, `IterationScaling` can be used to configure the number of iterations that
|
||||
/// a node runs for.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// # use gemla::bracket::*;
|
||||
/// # use gemla::error::Error;
|
||||
/// # use serde::{Deserialize, Serialize};
|
||||
/// # use std::fmt;
|
||||
/// # use std::str::FromStr;
|
||||
/// # use std::string::ToString;
|
||||
/// # use std::path;
|
||||
/// #
|
||||
/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq, Debug)]
|
||||
/// # struct TestState {
|
||||
/// # pub score: f64,
|
||||
/// # }
|
||||
/// #
|
||||
/// # impl TestState {
|
||||
/// # fn new(score: f64) -> TestState {
|
||||
/// # TestState { score: score }
|
||||
/// # }
|
||||
/// # }
|
||||
/// #
|
||||
/// # impl genetic_node::GeneticNode for TestState {
|
||||
/// # fn simulate(&mut self, iterations: u64) -> Result<(), Error> {
|
||||
/// # self.score += iterations as f64;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn get_fit_score(&self) -> f64 {
|
||||
/// # self.score
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn mutate(&mut self) -> Result<(), Error> {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn initialize() -> Result<Box<Self>, Error> {
|
||||
/// # Ok(Box::new(TestState { score: 0.0 }))
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn merge(left: &TestState, right: &TestState) -> Result<Box<Self>, Error> {
|
||||
/// # Ok(Box::new(left.clone()))
|
||||
/// # }
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn main() {
|
||||
/// let mut bracket = Bracket::<TestState>::initialize(path::PathBuf::from("./temp"))
|
||||
/// .expect("Bracket failed to initialize");
|
||||
///
|
||||
/// // Constant iteration scaling ensures that every node is simulated 5 times.
|
||||
/// bracket
|
||||
/// .mutate(|b| drop(b.iteration_scaling(IterationScaling::Constant(5))))
|
||||
/// .expect("Failed to set iteration scaling");
|
||||
///
|
||||
/// # std::fs::remove_file("./temp").expect("Unable to remove file");
|
||||
/// # }
|
||||
/// ```
|
||||
/// TODO
|
||||
#[derive(Clone, Serialize, Deserialize, Copy, Debug, PartialEq)]
|
||||
#[serde(tag = "enumType", content = "enumContent")]
|
||||
pub enum IterationScaling {
|
||||
|
@ -93,336 +36,87 @@ impl Default for IterationScaling {
|
|||
}
|
||||
}
|
||||
|
||||
/// Creates a tournament style bracket for simulating and evaluating nodes of type `T` implementing [`GeneticNode`].
|
||||
/// These nodes are built upwards as a balanced binary tree starting from the bottom. This results in `Bracket` building
|
||||
/// a separate tree of the same height then merging trees together. Evaluating populations between nodes and taking the strongest
|
||||
/// individuals.
|
||||
///
|
||||
/// [`GeneticNode`]: genetic_node::GeneticNode
|
||||
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
|
||||
pub struct Bracket<T>
|
||||
struct Bracket<T>
|
||||
where
|
||||
T: genetic_node::GeneticNode + Serialize,
|
||||
T: GeneticNode + Serialize,
|
||||
{
|
||||
pub tree: tree::Tree<Option<GeneticNodeWrapper<T>>>,
|
||||
iteration_scaling: IterationScaling,
|
||||
}
|
||||
|
||||
impl<T> Bracket<T>
|
||||
where
|
||||
T: genetic_node::GeneticNode
|
||||
+ Default
|
||||
+ DeserializeOwned
|
||||
+ Serialize
|
||||
+ Clone
|
||||
+ PartialEq
|
||||
+ Debug,
|
||||
where T: GeneticNode + Serialize
|
||||
{
|
||||
/// Initializes a bracket of type `T` storing the contents to `file_path`
|
||||
///
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// # use gemla::bracket::*;
|
||||
/// # use gemla::btree;
|
||||
/// # use gemla::tree;
|
||||
/// # use gemla::error::Error;
|
||||
/// # use serde::{Deserialize, Serialize};
|
||||
/// # use std::fmt;
|
||||
/// # use std::str::FromStr;
|
||||
/// # use std::string::ToString;
|
||||
/// # use std::path;
|
||||
/// #
|
||||
/// #[derive(Default, Deserialize, Serialize, Debug, Clone, PartialEq)]
|
||||
/// struct TestState {
|
||||
/// pub score: f64,
|
||||
/// }
|
||||
///
|
||||
/// # impl FromStr for TestState {
|
||||
/// # type Err = String;
|
||||
/// #
|
||||
/// # fn from_str(s: &str) -> Result<TestState, Self::Err> {
|
||||
/// # serde_json::from_str(s).map_err(|_| format!("Unable to parse string {}", s))
|
||||
/// # }
|
||||
/// # }
|
||||
/// #
|
||||
/// # impl fmt::Display for TestState {
|
||||
/// # fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
/// # write!(f, "{}", self.score)
|
||||
/// # }
|
||||
/// # }
|
||||
/// #
|
||||
/// impl TestState {
|
||||
/// fn new(score: f64) -> TestState {
|
||||
/// TestState { score: score }
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// impl genetic_node::GeneticNode for TestState {
|
||||
/// # fn simulate(&mut self, iterations: u64) -> Result<(), Error> {
|
||||
/// # self.score += iterations as f64;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn get_fit_score(&self) -> f64 {
|
||||
/// # self.score
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn mutate(&mut self) -> Result<(), Error> {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// #
|
||||
/// fn initialize() -> Result<Box<Self>, Error> {
|
||||
/// Ok(Box::new(TestState { score: 0.0 }))
|
||||
/// }
|
||||
///
|
||||
/// //...
|
||||
/// #
|
||||
/// # fn merge(left: &TestState, right: &TestState) -> Result<Box<Self>, Error> {
|
||||
/// # Ok(Box::new(left.clone()))
|
||||
/// # }
|
||||
/// }
|
||||
///
|
||||
/// # fn main() {
|
||||
/// let mut bracket = Bracket::<TestState>::initialize(path::PathBuf::from("./temp"))
|
||||
/// .expect("Bracket failed to initialize");
|
||||
///
|
||||
/// std::fs::remove_file("./temp").expect("Unable to remove file");
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn initialize(file_path: path::PathBuf) -> Result<FileLinked<Self>, Error> {
|
||||
Ok(FileLinked::new(
|
||||
Bracket {
|
||||
tree: btree!(Some(GeneticNodeWrapper::new()?)),
|
||||
iteration_scaling: IterationScaling::default(),
|
||||
fn increase_height(&mut self, _amount: usize) -> Result<(), Error> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn process_tree(&mut self) -> Result<(), Error> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a tournament style bracket for simulating and evaluating nodes of type `T` implementing [`GeneticNode`].
|
||||
/// These nodes are built upwards as a balanced binary tree starting from the bottom. This results in `Bracket` building
|
||||
/// a separate tree of the same height then merging trees together. Evaluating populations between nodes and taking the strongest
|
||||
/// individuals.
|
||||
///
|
||||
/// [`GeneticNode`]: genetic_node::GeneticNode
|
||||
pub struct Gemla<T>
|
||||
where T: GeneticNode + Serialize + DeserializeOwned
|
||||
{
|
||||
data: FileLinked<Bracket<T>>
|
||||
}
|
||||
|
||||
impl<T> Gemla<T>
|
||||
where
|
||||
T: GeneticNode
|
||||
+ Serialize
|
||||
+ DeserializeOwned
|
||||
+ Default
|
||||
{
|
||||
pub fn new(path: &PathBuf, overwrite: bool) -> Result<Self, Error> {
|
||||
match File::open(path) {
|
||||
Ok(file) => {
|
||||
drop(file);
|
||||
|
||||
Ok(Gemla {
|
||||
data:
|
||||
if overwrite {
|
||||
FileLinked::from_file(path)?
|
||||
} else {
|
||||
FileLinked::new(Bracket {
|
||||
tree: btree!(None),
|
||||
iteration_scaling: IterationScaling::default()
|
||||
}, path)?
|
||||
}
|
||||
})
|
||||
},
|
||||
file_path,
|
||||
)?)
|
||||
}
|
||||
|
||||
/// Given a bracket object, configures it's [`IterationScaling`].
|
||||
///
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// # use gemla::bracket::*;
|
||||
/// # use gemla::error::Error;
|
||||
/// # use serde::{Deserialize, Serialize};
|
||||
/// # use std::fmt;
|
||||
/// # use std::str::FromStr;
|
||||
/// # use std::string::ToString;
|
||||
/// # use std::path;
|
||||
/// #
|
||||
/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq, Debug)]
|
||||
/// # struct TestState {
|
||||
/// # pub score: f64,
|
||||
/// # }
|
||||
/// #
|
||||
/// # impl fmt::Display for TestState {
|
||||
/// # fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
/// # write!(f, "{}", self.score)
|
||||
/// # }
|
||||
/// # }
|
||||
/// #
|
||||
/// # impl TestState {
|
||||
/// # fn new(score: f64) -> TestState {
|
||||
/// # TestState { score: score }
|
||||
/// # }
|
||||
/// # }
|
||||
/// #
|
||||
/// # impl genetic_node::GeneticNode for TestState {
|
||||
/// # fn simulate(&mut self, iterations: u64) -> Result<(), Error> {
|
||||
/// # self.score += iterations as f64;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn get_fit_score(&self) -> f64 {
|
||||
/// # self.score
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn mutate(&mut self) -> Result<(), Error> {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn initialize() -> Result<Box<Self>, Error> {
|
||||
/// # Ok(Box::new(TestState { score: 0.0 }))
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn merge(left: &TestState, right: &TestState) -> Result<Box<Self>, Error> {
|
||||
/// # Ok(Box::new(left.clone()))
|
||||
/// # }
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn main() {
|
||||
/// let mut bracket = Bracket::<TestState>::initialize(path::PathBuf::from("./temp"))
|
||||
/// .expect("Bracket failed to initialize");
|
||||
///
|
||||
/// // Constant iteration scaling ensures that every node is simulated 5 times.
|
||||
/// bracket
|
||||
/// .mutate(|b| drop(b.iteration_scaling(IterationScaling::Constant(5))))
|
||||
/// .expect("Failed to set iteration scaling");
|
||||
///
|
||||
/// # std::fs::remove_file("./temp").expect("Unable to remove file");
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn iteration_scaling(&mut self, iteration_scaling: IterationScaling) -> &mut Self {
|
||||
self.iteration_scaling = iteration_scaling;
|
||||
self
|
||||
}
|
||||
|
||||
// Creates a balanced tree with the given `height` that will be used as a branch of the primary tree.
|
||||
// This additionally simulates and evaluates nodes in the branch as it is built.
|
||||
fn create_new_branch(
|
||||
&self,
|
||||
height: u64,
|
||||
) -> Result<tree::Tree<Option<GeneticNodeWrapper<T>>>, Error> {
|
||||
if height == 1 {
|
||||
let mut base_node = GeneticNodeWrapper::new()?;
|
||||
|
||||
base_node.process_node(match self.iteration_scaling {
|
||||
IterationScaling::Linear(x) => x * height,
|
||||
IterationScaling::Constant(x) => x,
|
||||
})?;
|
||||
|
||||
Ok(btree!(Some(base_node)))
|
||||
} else {
|
||||
let left = self.create_new_branch(height - 1)?;
|
||||
let right = self.create_new_branch(height - 1)?;
|
||||
let mut new_val = if left.val.clone().unwrap().data.unwrap().get_fit_score()
|
||||
>= right.val.clone().unwrap().data.unwrap().get_fit_score()
|
||||
{
|
||||
left.val.clone().unwrap()
|
||||
} else {
|
||||
right.val.clone().unwrap()
|
||||
};
|
||||
|
||||
new_val.process_node(match self.iteration_scaling {
|
||||
IterationScaling::Linear(x) => x * height,
|
||||
IterationScaling::Constant(x) => x,
|
||||
})?;
|
||||
|
||||
Ok(btree!(Some(new_val), left, right))
|
||||
Err(error) if error.kind() == ErrorKind::NotFound => {
|
||||
Ok(Gemla {
|
||||
data: FileLinked::new(Bracket {
|
||||
tree: btree!(None),
|
||||
iteration_scaling: IterationScaling::default()
|
||||
}, path)?
|
||||
})
|
||||
},
|
||||
Err(error) => Err(Error::IO(error))
|
||||
}
|
||||
}
|
||||
|
||||
/// Runs one step of simulation on the current bracket which includes:
|
||||
/// 1) Creating a new branch of the same height and performing the same steps for each subtree.
|
||||
/// 2) Simulating the top node of the current branch.
|
||||
/// 3) Comparing the top node of the current branch to the top node of the new branch.
|
||||
/// 4) Takes the best performing node and makes it the root of the tree.
|
||||
///
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// # use gemla::bracket::*;
|
||||
/// # use gemla::error::Error;
|
||||
/// # use serde::{Deserialize, Serialize};
|
||||
/// # use std::fmt;
|
||||
/// # use std::str::FromStr;
|
||||
/// # use std::string::ToString;
|
||||
/// # use std::path;
|
||||
/// #
|
||||
/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq, Debug)]
|
||||
/// # struct TestState {
|
||||
/// # pub score: f64,
|
||||
/// # }
|
||||
/// #
|
||||
/// # impl fmt::Display for TestState {
|
||||
/// # fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
/// # write!(f, "{}", self.score)
|
||||
/// # }
|
||||
/// # }
|
||||
/// #
|
||||
/// # impl TestState {
|
||||
/// # fn new(score: f64) -> TestState {
|
||||
/// # TestState { score: score }
|
||||
/// # }
|
||||
/// # }
|
||||
/// #
|
||||
/// # impl genetic_node::GeneticNode for TestState {
|
||||
/// # fn simulate(&mut self, iterations: u64) -> Result<(), Error> {
|
||||
/// # self.score += iterations as f64;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn get_fit_score(&self) -> f64 {
|
||||
/// # self.score
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn mutate(&mut self) -> Result<(), Error> {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn initialize() -> Result<Box<Self>, Error> {
|
||||
/// # Ok(Box::new(TestState { score: 0.0 }))
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn merge(left: &TestState, right: &TestState) -> Result<Box<Self>, Error> {
|
||||
/// # Ok(Box::new(left.clone()))
|
||||
/// # }
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn main() {
|
||||
/// let mut bracket = Bracket::<TestState>::initialize(path::PathBuf::from("./temp"))
|
||||
/// .expect("Bracket failed to initialize");
|
||||
///
|
||||
/// // Running simulations 3 times
|
||||
/// for _ in 0..3 {
|
||||
/// bracket
|
||||
/// .mutate(|b| drop(b.run_simulation_step()))
|
||||
/// .expect("Failed to run step");
|
||||
/// }
|
||||
///
|
||||
/// assert_eq!(bracket.readonly().tree.height(), 4);
|
||||
///
|
||||
/// # std::fs::remove_file("./temp").expect("Unable to remove file");
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn run_simulation_step(&mut self) -> Result<&mut Self, Error> {
|
||||
let new_branch = self.create_new_branch(self.tree.height())?;
|
||||
pub fn simulate(&mut self, steps: u64) -> Result<(), Error> {
|
||||
self.data.mutate(|b| b.increase_height(steps as usize))??;
|
||||
|
||||
self.tree
|
||||
.val
|
||||
.clone()
|
||||
.unwrap()
|
||||
.process_node(match self.iteration_scaling {
|
||||
IterationScaling::Linear(x) => (x * self.tree.height()),
|
||||
IterationScaling::Constant(x) => x,
|
||||
})?;
|
||||
self.data.mutate(|b| b.process_tree())??;
|
||||
|
||||
let new_val = if new_branch
|
||||
.val
|
||||
.clone()
|
||||
.unwrap()
|
||||
.data
|
||||
.unwrap()
|
||||
.get_fit_score()
|
||||
>= self.tree.val.clone().unwrap().data.unwrap().get_fit_score()
|
||||
{
|
||||
new_branch.val.clone()
|
||||
} else {
|
||||
self.tree.val.clone()
|
||||
};
|
||||
|
||||
self.tree = btree!(new_val, new_branch, self.tree.clone());
|
||||
|
||||
Ok(self)
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::bracket::*;
|
||||
use crate::tree::*;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::str::FromStr;
|
||||
|
@ -446,10 +140,6 @@ mod tests {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn get_fit_score(&self) -> f64 {
|
||||
self.score
|
||||
}
|
||||
|
||||
fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
|
||||
Ok(())
|
||||
}
|
||||
|
@ -463,66 +153,11 @@ mod tests {
|
|||
}
|
||||
|
||||
fn merge(left: &TestState, right: &TestState) -> Result<Box<TestState>, Error> {
|
||||
Ok(Box::new(if left.get_fit_score() > right.get_fit_score() {
|
||||
Ok(Box::new(if left.score > right.score {
|
||||
left.clone()
|
||||
} else {
|
||||
right.clone()
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_new() {
|
||||
let bracket = Bracket::<TestState>::initialize(path::PathBuf::from("./temp"))
|
||||
.expect("Bracket failed to initialize");
|
||||
|
||||
assert_eq!(
|
||||
bracket,
|
||||
file_linked::FileLinked::new(
|
||||
Bracket {
|
||||
tree: Tree {
|
||||
val: Some(GeneticNodeWrapper::new().unwrap()),
|
||||
left: None,
|
||||
right: None
|
||||
},
|
||||
iteration_scaling: IterationScaling::Constant(1)
|
||||
},
|
||||
path::PathBuf::from("./temp")
|
||||
)
|
||||
.unwrap()
|
||||
);
|
||||
|
||||
std::fs::remove_file("./temp").expect("Unable to remove file");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_run() {
|
||||
let mut bracket = Bracket::<TestState>::initialize(path::PathBuf::from("./temp2"))
|
||||
.expect("Bracket failed to initialize");
|
||||
|
||||
bracket
|
||||
.mutate(|b| drop(b.iteration_scaling(IterationScaling::Linear(2))))
|
||||
.expect("Failed to set iteration scaling");
|
||||
for _ in 0..3 {
|
||||
bracket
|
||||
.mutate(|b| drop(b.run_simulation_step()))
|
||||
.expect("Failed to run step");
|
||||
}
|
||||
|
||||
assert_eq!(bracket.readonly().tree.height(), 4);
|
||||
assert_eq!(
|
||||
bracket
|
||||
.readonly()
|
||||
.tree
|
||||
.val
|
||||
.clone()
|
||||
.unwrap()
|
||||
.data
|
||||
.unwrap()
|
||||
.score,
|
||||
15.0
|
||||
);
|
||||
|
||||
std::fs::remove_file("./temp2").expect("Unable to remove file");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,6 +5,8 @@ pub enum Error {
|
|||
#[error(transparent)]
|
||||
FileLinked(file_linked::Error),
|
||||
#[error(transparent)]
|
||||
IO(std::io::Error),
|
||||
#[error(transparent)]
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
|
@ -16,3 +18,9 @@ impl From<file_linked::Error> for Error {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for Error {
|
||||
fn from(error: std::io::Error) -> Error {
|
||||
Error::IO(error)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue