Refactoring bracket interface
This commit is contained in:
parent
9081fb0b3c
commit
569a17f145
5 changed files with 137 additions and 599 deletions
|
@ -2,12 +2,10 @@
|
||||||
|
|
||||||
extern crate serde;
|
extern crate serde;
|
||||||
|
|
||||||
use std::fs;
|
use std::fs::File;
|
||||||
use std::io;
|
|
||||||
use std::io::prelude::*;
|
use std::io::prelude::*;
|
||||||
use std::path;
|
use std::path::PathBuf;
|
||||||
|
use anyhow::Context;
|
||||||
use anyhow::{anyhow, Context};
|
|
||||||
use serde::de::DeserializeOwned;
|
use serde::de::DeserializeOwned;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
@ -29,7 +27,7 @@ where
|
||||||
T: Serialize,
|
T: Serialize,
|
||||||
{
|
{
|
||||||
val: T,
|
val: T,
|
||||||
path: path::PathBuf,
|
path: PathBuf,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> FileLinked<T>
|
impl<T> FileLinked<T>
|
||||||
|
@ -44,7 +42,7 @@ where
|
||||||
/// # use serde::{Deserialize, Serialize};
|
/// # use serde::{Deserialize, Serialize};
|
||||||
/// # use std::fmt;
|
/// # use std::fmt;
|
||||||
/// # use std::string::ToString;
|
/// # use std::string::ToString;
|
||||||
/// # use std::path;
|
/// # use std::path::PathBuf;
|
||||||
/// #
|
/// #
|
||||||
/// # #[derive(Deserialize, Serialize)]
|
/// # #[derive(Deserialize, Serialize)]
|
||||||
/// # struct Test {
|
/// # struct Test {
|
||||||
|
@ -60,7 +58,7 @@ where
|
||||||
/// c: 3.0
|
/// c: 3.0
|
||||||
/// };
|
/// };
|
||||||
///
|
///
|
||||||
/// let linked_test = FileLinked::new(test, path::PathBuf::from("./temp"))
|
/// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"))
|
||||||
/// .expect("Unable to create file linked object");
|
/// .expect("Unable to create file linked object");
|
||||||
///
|
///
|
||||||
/// assert_eq!(linked_test.readonly().a, 1);
|
/// assert_eq!(linked_test.readonly().a, 1);
|
||||||
|
@ -82,7 +80,7 @@ where
|
||||||
/// # use serde::{Deserialize, Serialize};
|
/// # use serde::{Deserialize, Serialize};
|
||||||
/// # use std::fmt;
|
/// # use std::fmt;
|
||||||
/// # use std::string::ToString;
|
/// # use std::string::ToString;
|
||||||
/// # use std::path;
|
/// # use std::path::PathBuf;
|
||||||
/// #
|
/// #
|
||||||
/// #[derive(Deserialize, Serialize)]
|
/// #[derive(Deserialize, Serialize)]
|
||||||
/// struct Test {
|
/// struct Test {
|
||||||
|
@ -98,7 +96,7 @@ where
|
||||||
/// c: 3.0
|
/// c: 3.0
|
||||||
/// };
|
/// };
|
||||||
///
|
///
|
||||||
/// let linked_test = FileLinked::new(test, path::PathBuf::from("./temp"))
|
/// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"))
|
||||||
/// .expect("Unable to create file linked object");
|
/// .expect("Unable to create file linked object");
|
||||||
///
|
///
|
||||||
/// assert_eq!(linked_test.readonly().a, 1);
|
/// assert_eq!(linked_test.readonly().a, 1);
|
||||||
|
@ -108,20 +106,14 @@ where
|
||||||
/// # std::fs::remove_file("./temp").expect("Unable to remove file");
|
/// # std::fs::remove_file("./temp").expect("Unable to remove file");
|
||||||
/// # }
|
/// # }
|
||||||
/// ```
|
/// ```
|
||||||
pub fn new(val: T, path: path::PathBuf) -> Result<FileLinked<T>, Error> {
|
pub fn new(val: T, path: &PathBuf) -> Result<FileLinked<T>, Error> {
|
||||||
let result = FileLinked { val, path };
|
let result = FileLinked { val, path: path.clone() };
|
||||||
|
|
||||||
result.write_data()?;
|
result.write_data()?;
|
||||||
|
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn write_data(&self) -> Result<(), Error> {
|
fn write_data(&self) -> Result<(), Error> {
|
||||||
let mut file = fs::OpenOptions::new()
|
let mut file = File::create(&self.path)
|
||||||
.write(true)
|
|
||||||
.create(true)
|
|
||||||
.truncate(true)
|
|
||||||
.open(&self.path)
|
|
||||||
.with_context(|| format!("Unable to open path {}", self.path.display()))?;
|
.with_context(|| format!("Unable to open path {}", self.path.display()))?;
|
||||||
|
|
||||||
write!(
|
write!(
|
||||||
|
@ -143,7 +135,7 @@ where
|
||||||
/// # use serde::{Deserialize, Serialize};
|
/// # use serde::{Deserialize, Serialize};
|
||||||
/// # use std::fmt;
|
/// # use std::fmt;
|
||||||
/// # use std::string::ToString;
|
/// # use std::string::ToString;
|
||||||
/// # use std::path;
|
/// # use std::path::PathBuf;
|
||||||
/// #
|
/// #
|
||||||
/// # #[derive(Deserialize, Serialize)]
|
/// # #[derive(Deserialize, Serialize)]
|
||||||
/// # struct Test {
|
/// # struct Test {
|
||||||
|
@ -159,7 +151,7 @@ where
|
||||||
/// c: 0.0
|
/// c: 0.0
|
||||||
/// };
|
/// };
|
||||||
///
|
///
|
||||||
/// let mut linked_test = FileLinked::new(test, path::PathBuf::from("./temp"))
|
/// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"))
|
||||||
/// .expect("Unable to create file linked object");
|
/// .expect("Unable to create file linked object");
|
||||||
///
|
///
|
||||||
/// assert_eq!(linked_test.readonly().a, 1);
|
/// assert_eq!(linked_test.readonly().a, 1);
|
||||||
|
@ -189,7 +181,7 @@ where
|
||||||
/// # use serde::{Deserialize, Serialize};
|
/// # use serde::{Deserialize, Serialize};
|
||||||
/// # use std::fmt;
|
/// # use std::fmt;
|
||||||
/// # use std::string::ToString;
|
/// # use std::string::ToString;
|
||||||
/// # use std::path;
|
/// # use std::path::PathBuf;
|
||||||
/// #
|
/// #
|
||||||
/// # #[derive(Deserialize, Serialize)]
|
/// # #[derive(Deserialize, Serialize)]
|
||||||
/// # struct Test {
|
/// # struct Test {
|
||||||
|
@ -205,7 +197,7 @@ where
|
||||||
/// c: 0.0
|
/// c: 0.0
|
||||||
/// };
|
/// };
|
||||||
///
|
///
|
||||||
/// let mut linked_test = FileLinked::new(test, path::PathBuf::from("./temp"))
|
/// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"))
|
||||||
/// .expect("Unable to create file linked object");
|
/// .expect("Unable to create file linked object");
|
||||||
///
|
///
|
||||||
/// assert_eq!(linked_test.readonly().a, 1);
|
/// assert_eq!(linked_test.readonly().a, 1);
|
||||||
|
@ -245,7 +237,7 @@ where
|
||||||
/// # use std::fs;
|
/// # use std::fs;
|
||||||
/// # use std::fs::OpenOptions;
|
/// # use std::fs::OpenOptions;
|
||||||
/// # use std::io::Write;
|
/// # use std::io::Write;
|
||||||
/// # use std::path;
|
/// # use std::path::PathBuf;
|
||||||
/// #
|
/// #
|
||||||
/// # #[derive(Deserialize, Serialize)]
|
/// # #[derive(Deserialize, Serialize)]
|
||||||
/// # struct Test {
|
/// # struct Test {
|
||||||
|
@ -261,7 +253,7 @@ where
|
||||||
/// c: 3.0
|
/// c: 3.0
|
||||||
/// };
|
/// };
|
||||||
///
|
///
|
||||||
/// let path = path::PathBuf::from("./temp");
|
/// let path = PathBuf::from("./temp");
|
||||||
///
|
///
|
||||||
/// let mut file = OpenOptions::new()
|
/// let mut file = OpenOptions::new()
|
||||||
/// .write(true)
|
/// .write(true)
|
||||||
|
@ -275,7 +267,7 @@ where
|
||||||
///
|
///
|
||||||
/// drop(file);
|
/// drop(file);
|
||||||
///
|
///
|
||||||
/// let mut linked_test = FileLinked::<Test>::from_file(path)
|
/// let mut linked_test = FileLinked::<Test>::from_file(&path)
|
||||||
/// .expect("Unable to create file linked object");
|
/// .expect("Unable to create file linked object");
|
||||||
///
|
///
|
||||||
/// assert_eq!(linked_test.readonly().a, test.a);
|
/// assert_eq!(linked_test.readonly().a, test.a);
|
||||||
|
@ -287,27 +279,14 @@ where
|
||||||
/// # Ok(())
|
/// # Ok(())
|
||||||
/// # }
|
/// # }
|
||||||
/// ```
|
/// ```
|
||||||
pub fn from_file(path: path::PathBuf) -> Result<FileLinked<T>, Error> {
|
pub fn from_file(path: &PathBuf) -> Result<FileLinked<T>, Error> {
|
||||||
let metadata = path
|
let file = File::open(path)
|
||||||
.metadata()
|
.with_context(|| format!("Unable to open file {}", path.display()))?;
|
||||||
.with_context(|| format!("Error obtaining metadata for {}", path.display()))?;
|
|
||||||
|
|
||||||
if metadata.is_file() {
|
let val = serde_json::from_reader(file)
|
||||||
let file = fs::OpenOptions::new()
|
.with_context(|| String::from("Unable to parse value from file."))?;
|
||||||
.read(true)
|
|
||||||
.open(&path)
|
|
||||||
.with_context(|| format!("Unable to open file {}", path.display()))?;
|
|
||||||
|
|
||||||
let val = serde_json::from_reader(file)
|
Ok(FileLinked { val, path: path.clone() })
|
||||||
.with_context(|| String::from("Unable to parse value from file."))?;
|
|
||||||
|
|
||||||
Ok(FileLinked { val, path })
|
|
||||||
} else {
|
|
||||||
return Err(Error::IO(io::Error::new(
|
|
||||||
io::ErrorKind::Other,
|
|
||||||
anyhow!("{} is not a file.", path.display()),
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -319,7 +298,7 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn test_mutate() -> Result<(), Error> {
|
fn test_mutate() -> Result<(), Error> {
|
||||||
let list = vec![1, 2, 3, 4];
|
let list = vec![1, 2, 3, 4];
|
||||||
let mut file_linked_list = FileLinked::new(list, path::PathBuf::from("test.txt"))?;
|
let mut file_linked_list = FileLinked::new(list, &PathBuf::from("test.txt"))?;
|
||||||
|
|
||||||
assert_eq!(format!("{:?}", file_linked_list.readonly()), "[1, 2, 3, 4]");
|
assert_eq!(format!("{:?}", file_linked_list.readonly()), "[1, 2, 3, 4]");
|
||||||
|
|
||||||
|
|
|
@ -40,14 +40,6 @@ impl GeneticNode for TestState {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_fit_score(&self) -> f64 {
|
|
||||||
self.population
|
|
||||||
.clone()
|
|
||||||
.into_iter()
|
|
||||||
.reduce(f64::max)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn calculate_scores_and_trim(&mut self) -> Result<(), error::Error> {
|
fn calculate_scores_and_trim(&mut self) -> Result<(), error::Error> {
|
||||||
let mut v = self.population.clone();
|
let mut v = self.population.clone();
|
||||||
|
|
||||||
|
@ -144,16 +136,6 @@ mod tests {
|
||||||
.all(|(&a, &b)| b >= a - 30.0 && b <= a + 30.0))
|
.all(|(&a, &b)| b >= a - 30.0 && b <= a + 30.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_get_fit_score() {
|
|
||||||
let state = TestState {
|
|
||||||
thread_rng: thread_rng(),
|
|
||||||
population: vec![1.0, 1.0, 2.0, 3.0],
|
|
||||||
};
|
|
||||||
|
|
||||||
assert_eq!(state.get_fit_score(), 3.0);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_calculate_scores_and_trim() {
|
fn test_calculate_scores_and_trim() {
|
||||||
let mut state = TestState {
|
let mut state = TestState {
|
||||||
|
|
|
@ -53,10 +53,6 @@ pub trait GeneticNode {
|
||||||
/// # Ok(())
|
/// # Ok(())
|
||||||
/// # }
|
/// # }
|
||||||
/// #
|
/// #
|
||||||
/// # fn get_fit_score(&self) -> f64 {
|
|
||||||
/// # self.fit_score
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
|
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
|
||||||
/// # Ok(())
|
/// # Ok(())
|
||||||
/// # }
|
/// # }
|
||||||
|
@ -72,7 +68,7 @@ pub trait GeneticNode {
|
||||||
///
|
///
|
||||||
/// # fn main() -> Result<(), Error> {
|
/// # fn main() -> Result<(), Error> {
|
||||||
/// let node = Node::initialize()?;
|
/// let node = Node::initialize()?;
|
||||||
/// assert_eq!(node.get_fit_score(), 0.0);
|
/// assert_eq!(node.fit_score, 0.0);
|
||||||
/// # Ok(())
|
/// # Ok(())
|
||||||
/// # }
|
/// # }
|
||||||
/// ```
|
/// ```
|
||||||
|
@ -81,7 +77,7 @@ pub trait GeneticNode {
|
||||||
/// Runs a simulation on the state object for the given number of `iterations` in order to guage it's fitness.
|
/// Runs a simulation on the state object for the given number of `iterations` in order to guage it's fitness.
|
||||||
/// This will be called for every node in a bracket before evaluating it's fitness against other nodes.
|
/// This will be called for every node in a bracket before evaluating it's fitness against other nodes.
|
||||||
///
|
///
|
||||||
/// #Examples
|
/// # Examples
|
||||||
///
|
///
|
||||||
/// ```
|
/// ```
|
||||||
/// # use gemla::bracket::genetic_node::GeneticNode;
|
/// # use gemla::bracket::genetic_node::GeneticNode;
|
||||||
|
@ -104,7 +100,17 @@ pub trait GeneticNode {
|
||||||
/// # Ok(())
|
/// # Ok(())
|
||||||
/// }
|
/// }
|
||||||
/// }
|
/// }
|
||||||
///
|
///
|
||||||
|
/// # impl Node {
|
||||||
|
/// # fn get_fit_score(&self) -> f64 {
|
||||||
|
/// # self.models
|
||||||
|
/// # .iter()
|
||||||
|
/// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap())
|
||||||
|
/// # .unwrap()
|
||||||
|
/// # .fit_score
|
||||||
|
/// # }
|
||||||
|
/// # }
|
||||||
|
/// #
|
||||||
/// impl GeneticNode for Node {
|
/// impl GeneticNode for Node {
|
||||||
/// # fn initialize() -> Result<Box<Node>, Error> {
|
/// # fn initialize() -> Result<Box<Node>, Error> {
|
||||||
/// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]}))
|
/// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]}))
|
||||||
|
@ -121,10 +127,6 @@ pub trait GeneticNode {
|
||||||
/// }
|
/// }
|
||||||
///
|
///
|
||||||
/// //...
|
/// //...
|
||||||
///
|
|
||||||
/// # fn get_fit_score(&self) -> f64 {
|
|
||||||
/// # self.models.iter().max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap()).unwrap().fit_score
|
|
||||||
/// # }
|
|
||||||
/// #
|
/// #
|
||||||
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
|
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
|
||||||
/// # Ok(())
|
/// # Ok(())
|
||||||
|
@ -148,76 +150,6 @@ pub trait GeneticNode {
|
||||||
/// ```
|
/// ```
|
||||||
fn simulate(&mut self, iterations: u64) -> Result<(), Error>;
|
fn simulate(&mut self, iterations: u64) -> Result<(), Error>;
|
||||||
|
|
||||||
/// Returns a fit score associated with the nodes performance.
|
|
||||||
/// This will be used by a bracket in order to determine the most successful child.
|
|
||||||
///
|
|
||||||
/// # Examples
|
|
||||||
/// ```
|
|
||||||
/// # use gemla::bracket::genetic_node::GeneticNode;
|
|
||||||
/// # use gemla::error::Error;
|
|
||||||
/// #
|
|
||||||
/// struct Model {
|
|
||||||
/// pub fit_score: f64,
|
|
||||||
/// //...
|
|
||||||
/// }
|
|
||||||
///
|
|
||||||
/// struct Node {
|
|
||||||
/// pub models: Vec<Model>,
|
|
||||||
/// //...
|
|
||||||
/// }
|
|
||||||
///
|
|
||||||
/// # impl Model {
|
|
||||||
/// # fn fit(&mut self, epochs: u64) -> Result<(), Error> {
|
|
||||||
/// # //...
|
|
||||||
/// # self.fit_score += epochs as f64;
|
|
||||||
/// # Ok(())
|
|
||||||
/// # }
|
|
||||||
/// # }
|
|
||||||
///
|
|
||||||
/// impl GeneticNode for Node {
|
|
||||||
/// # fn initialize() -> Result<Box<Node>, Error> {
|
|
||||||
/// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]}))
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # //...
|
|
||||||
/// #
|
|
||||||
/// # fn simulate(&mut self, iterations: u64) -> Result<(), Error> {
|
|
||||||
/// # for m in self.models.iter_mut()
|
|
||||||
/// # {
|
|
||||||
/// # m.fit(iterations)?;
|
|
||||||
/// # }
|
|
||||||
/// # Ok(())
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// //...
|
|
||||||
///
|
|
||||||
/// fn get_fit_score(&self) -> f64 {
|
|
||||||
/// self.models.iter().max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap()).unwrap().fit_score
|
|
||||||
/// }
|
|
||||||
///
|
|
||||||
/// //...
|
|
||||||
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
|
|
||||||
/// # Ok(())
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # fn mutate(&mut self) -> Result<(), Error> {
|
|
||||||
/// # Ok(())
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # fn merge(left: &Node, right: &Node) -> Result<Box<Node>, Error> {
|
|
||||||
/// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]}))
|
|
||||||
/// # }
|
|
||||||
/// }
|
|
||||||
///
|
|
||||||
/// # fn main() -> Result<(), Error> {
|
|
||||||
/// let mut node = Node::initialize()?;
|
|
||||||
/// node.simulate(5)?;
|
|
||||||
/// assert_eq!(node.get_fit_score(), 5.0);
|
|
||||||
/// # Ok(())
|
|
||||||
/// # }
|
|
||||||
/// ```
|
|
||||||
fn get_fit_score(&self) -> f64;
|
|
||||||
|
|
||||||
/// Used when scoring the nodes after simulating and should remove underperforming children.
|
/// Used when scoring the nodes after simulating and should remove underperforming children.
|
||||||
///
|
///
|
||||||
/// # Examples
|
/// # Examples
|
||||||
|
@ -239,11 +171,23 @@ pub trait GeneticNode {
|
||||||
/// # impl Model {
|
/// # impl Model {
|
||||||
/// # fn fit(&mut self, epochs: u64) -> Result<(), Error> {
|
/// # fn fit(&mut self, epochs: u64) -> Result<(), Error> {
|
||||||
/// # //...
|
/// # //...
|
||||||
/// # self.fit_score += epochs as f64;
|
/// # self.fit_score += epochs as f64;
|
||||||
/// # Ok(())
|
/// # Ok(())
|
||||||
/// # }
|
/// # }
|
||||||
/// # }
|
/// # }
|
||||||
///
|
/// #
|
||||||
|
/// #
|
||||||
|
/// # impl Node {
|
||||||
|
/// # fn get_fit_score(&self) -> f64 {
|
||||||
|
/// # self.models
|
||||||
|
/// # .iter()
|
||||||
|
/// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap())
|
||||||
|
/// # .unwrap()
|
||||||
|
/// # .fit_score
|
||||||
|
/// # }
|
||||||
|
/// # }
|
||||||
|
/// #
|
||||||
|
///
|
||||||
/// impl GeneticNode for Node {
|
/// impl GeneticNode for Node {
|
||||||
/// # fn initialize() -> Result<Box<Node>, Error> {
|
/// # fn initialize() -> Result<Box<Node>, Error> {
|
||||||
/// # Ok(Box::new(Node {
|
/// # Ok(Box::new(Node {
|
||||||
|
@ -269,14 +213,6 @@ pub trait GeneticNode {
|
||||||
/// #
|
/// #
|
||||||
/// //...
|
/// //...
|
||||||
///
|
///
|
||||||
/// # fn get_fit_score(&self) -> f64 {
|
|
||||||
/// # self.models
|
|
||||||
/// # .iter()
|
|
||||||
/// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap())
|
|
||||||
/// # .unwrap()
|
|
||||||
/// # .fit_score
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
|
/// fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
|
||||||
/// self.models.sort_by(|a, b| a.fit_score.partial_cmp(&b.fit_score).unwrap().reverse());
|
/// self.models.sort_by(|a, b| a.fit_score.partial_cmp(&b.fit_score).unwrap().reverse());
|
||||||
/// self.models.truncate(3);
|
/// self.models.truncate(3);
|
||||||
|
@ -326,6 +262,16 @@ pub trait GeneticNode {
|
||||||
/// population_size: i64,
|
/// population_size: i64,
|
||||||
/// //...
|
/// //...
|
||||||
/// }
|
/// }
|
||||||
|
/// #
|
||||||
|
/// # impl Node {
|
||||||
|
/// # fn get_fit_score(&self) -> f64 {
|
||||||
|
/// # self.models
|
||||||
|
/// # .iter()
|
||||||
|
/// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap())
|
||||||
|
/// # .unwrap()
|
||||||
|
/// # .fit_score
|
||||||
|
/// # }
|
||||||
|
/// # }
|
||||||
///
|
///
|
||||||
/// # impl Model {
|
/// # impl Model {
|
||||||
/// # fn fit(&mut self, epochs: u64) -> Result<(), Error> {
|
/// # fn fit(&mut self, epochs: u64) -> Result<(), Error> {
|
||||||
|
@ -364,14 +310,6 @@ pub trait GeneticNode {
|
||||||
/// # Ok(())
|
/// # Ok(())
|
||||||
/// # }
|
/// # }
|
||||||
/// #
|
/// #
|
||||||
/// # fn get_fit_score(&self) -> f64 {
|
|
||||||
/// # self.models
|
|
||||||
/// # .iter()
|
|
||||||
/// # .max_by(|m1, m2| m1.fit_score.partial_cmp(&m2.fit_score).unwrap())
|
|
||||||
/// # .unwrap()
|
|
||||||
/// # .fit_score
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
|
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
|
||||||
/// # self.models.sort_by(|a, b| a.fit_score.partial_cmp(&b.fit_score).unwrap().reverse());
|
/// # self.models.sort_by(|a, b| a.fit_score.partial_cmp(&b.fit_score).unwrap().reverse());
|
||||||
/// # self.models.truncate(3);
|
/// # self.models.truncate(3);
|
||||||
|
@ -458,10 +396,6 @@ where
|
||||||
/// # Ok(())
|
/// # Ok(())
|
||||||
/// # }
|
/// # }
|
||||||
/// #
|
/// #
|
||||||
/// # fn get_fit_score(&self) -> f64 {
|
|
||||||
/// # self.fit_score
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
|
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
|
||||||
/// # Ok(())
|
/// # Ok(())
|
||||||
/// # }
|
/// # }
|
||||||
|
@ -477,7 +411,7 @@ where
|
||||||
///
|
///
|
||||||
/// # fn main() -> Result<(), Error> {
|
/// # fn main() -> Result<(), Error> {
|
||||||
/// let mut wrapped_node = GeneticNodeWrapper::<Node>::new()?;
|
/// let mut wrapped_node = GeneticNodeWrapper::<Node>::new()?;
|
||||||
/// assert_eq!(wrapped_node.data.unwrap().get_fit_score(), 0.0);
|
/// assert_eq!(wrapped_node.data.unwrap().fit_score, 0.0);
|
||||||
/// # Ok(())
|
/// # Ok(())
|
||||||
/// # }
|
/// # }
|
||||||
/// ```
|
/// ```
|
||||||
|
|
|
@ -5,78 +5,21 @@ pub mod genetic_node;
|
||||||
|
|
||||||
use crate::error::Error;
|
use crate::error::Error;
|
||||||
use crate::tree;
|
use crate::tree;
|
||||||
use genetic_node::GeneticNodeWrapper;
|
use genetic_node::{GeneticNodeWrapper, GeneticNode};
|
||||||
|
|
||||||
use file_linked::FileLinked;
|
use file_linked::FileLinked;
|
||||||
use serde::de::DeserializeOwned;
|
use serde::de::DeserializeOwned;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::path;
|
use std::path::PathBuf;
|
||||||
|
use std::fs::File;
|
||||||
|
use std::io::ErrorKind;
|
||||||
|
|
||||||
/// As the bracket tree increases in height, `IterationScaling` can be used to configure the number of iterations that
|
/// As the bracket tree increases in height, `IterationScaling` can be used to configure the number of iterations that
|
||||||
/// a node runs for.
|
/// a node runs for.
|
||||||
///
|
///
|
||||||
/// # Examples
|
/// # Examples
|
||||||
///
|
///
|
||||||
/// ```
|
/// TODO
|
||||||
/// # use gemla::bracket::*;
|
|
||||||
/// # use gemla::error::Error;
|
|
||||||
/// # use serde::{Deserialize, Serialize};
|
|
||||||
/// # use std::fmt;
|
|
||||||
/// # use std::str::FromStr;
|
|
||||||
/// # use std::string::ToString;
|
|
||||||
/// # use std::path;
|
|
||||||
/// #
|
|
||||||
/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq, Debug)]
|
|
||||||
/// # struct TestState {
|
|
||||||
/// # pub score: f64,
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # impl TestState {
|
|
||||||
/// # fn new(score: f64) -> TestState {
|
|
||||||
/// # TestState { score: score }
|
|
||||||
/// # }
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # impl genetic_node::GeneticNode for TestState {
|
|
||||||
/// # fn simulate(&mut self, iterations: u64) -> Result<(), Error> {
|
|
||||||
/// # self.score += iterations as f64;
|
|
||||||
/// # Ok(())
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # fn get_fit_score(&self) -> f64 {
|
|
||||||
/// # self.score
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
|
|
||||||
/// # Ok(())
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # fn mutate(&mut self) -> Result<(), Error> {
|
|
||||||
/// # Ok(())
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # fn initialize() -> Result<Box<Self>, Error> {
|
|
||||||
/// # Ok(Box::new(TestState { score: 0.0 }))
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # fn merge(left: &TestState, right: &TestState) -> Result<Box<Self>, Error> {
|
|
||||||
/// # Ok(Box::new(left.clone()))
|
|
||||||
/// # }
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # fn main() {
|
|
||||||
/// let mut bracket = Bracket::<TestState>::initialize(path::PathBuf::from("./temp"))
|
|
||||||
/// .expect("Bracket failed to initialize");
|
|
||||||
///
|
|
||||||
/// // Constant iteration scaling ensures that every node is simulated 5 times.
|
|
||||||
/// bracket
|
|
||||||
/// .mutate(|b| drop(b.iteration_scaling(IterationScaling::Constant(5))))
|
|
||||||
/// .expect("Failed to set iteration scaling");
|
|
||||||
///
|
|
||||||
/// # std::fs::remove_file("./temp").expect("Unable to remove file");
|
|
||||||
/// # }
|
|
||||||
/// ```
|
|
||||||
#[derive(Clone, Serialize, Deserialize, Copy, Debug, PartialEq)]
|
#[derive(Clone, Serialize, Deserialize, Copy, Debug, PartialEq)]
|
||||||
#[serde(tag = "enumType", content = "enumContent")]
|
#[serde(tag = "enumType", content = "enumContent")]
|
||||||
pub enum IterationScaling {
|
pub enum IterationScaling {
|
||||||
|
@ -93,336 +36,87 @@ impl Default for IterationScaling {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a tournament style bracket for simulating and evaluating nodes of type `T` implementing [`GeneticNode`].
|
|
||||||
/// These nodes are built upwards as a balanced binary tree starting from the bottom. This results in `Bracket` building
|
|
||||||
/// a separate tree of the same height then merging trees together. Evaluating populations between nodes and taking the strongest
|
|
||||||
/// individuals.
|
|
||||||
///
|
|
||||||
/// [`GeneticNode`]: genetic_node::GeneticNode
|
|
||||||
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
|
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
|
||||||
pub struct Bracket<T>
|
struct Bracket<T>
|
||||||
where
|
where
|
||||||
T: genetic_node::GeneticNode + Serialize,
|
T: GeneticNode + Serialize,
|
||||||
{
|
{
|
||||||
pub tree: tree::Tree<Option<GeneticNodeWrapper<T>>>,
|
pub tree: tree::Tree<Option<GeneticNodeWrapper<T>>>,
|
||||||
iteration_scaling: IterationScaling,
|
iteration_scaling: IterationScaling,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> Bracket<T>
|
impl<T> Bracket<T>
|
||||||
where
|
where T: GeneticNode + Serialize
|
||||||
T: genetic_node::GeneticNode
|
|
||||||
+ Default
|
|
||||||
+ DeserializeOwned
|
|
||||||
+ Serialize
|
|
||||||
+ Clone
|
|
||||||
+ PartialEq
|
|
||||||
+ Debug,
|
|
||||||
{
|
{
|
||||||
/// Initializes a bracket of type `T` storing the contents to `file_path`
|
fn increase_height(&mut self, _amount: usize) -> Result<(), Error> {
|
||||||
///
|
Ok(())
|
||||||
/// # Examples
|
}
|
||||||
/// ```
|
|
||||||
/// # use gemla::bracket::*;
|
fn process_tree(&mut self) -> Result<(), Error> {
|
||||||
/// # use gemla::btree;
|
Ok(())
|
||||||
/// # use gemla::tree;
|
}
|
||||||
/// # use gemla::error::Error;
|
}
|
||||||
/// # use serde::{Deserialize, Serialize};
|
|
||||||
/// # use std::fmt;
|
/// Creates a tournament style bracket for simulating and evaluating nodes of type `T` implementing [`GeneticNode`].
|
||||||
/// # use std::str::FromStr;
|
/// These nodes are built upwards as a balanced binary tree starting from the bottom. This results in `Bracket` building
|
||||||
/// # use std::string::ToString;
|
/// a separate tree of the same height then merging trees together. Evaluating populations between nodes and taking the strongest
|
||||||
/// # use std::path;
|
/// individuals.
|
||||||
/// #
|
///
|
||||||
/// #[derive(Default, Deserialize, Serialize, Debug, Clone, PartialEq)]
|
/// [`GeneticNode`]: genetic_node::GeneticNode
|
||||||
/// struct TestState {
|
pub struct Gemla<T>
|
||||||
/// pub score: f64,
|
where T: GeneticNode + Serialize + DeserializeOwned
|
||||||
/// }
|
{
|
||||||
///
|
data: FileLinked<Bracket<T>>
|
||||||
/// # impl FromStr for TestState {
|
}
|
||||||
/// # type Err = String;
|
|
||||||
/// #
|
impl<T> Gemla<T>
|
||||||
/// # fn from_str(s: &str) -> Result<TestState, Self::Err> {
|
where
|
||||||
/// # serde_json::from_str(s).map_err(|_| format!("Unable to parse string {}", s))
|
T: GeneticNode
|
||||||
/// # }
|
+ Serialize
|
||||||
/// # }
|
+ DeserializeOwned
|
||||||
/// #
|
+ Default
|
||||||
/// # impl fmt::Display for TestState {
|
{
|
||||||
/// # fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
pub fn new(path: &PathBuf, overwrite: bool) -> Result<Self, Error> {
|
||||||
/// # write!(f, "{}", self.score)
|
match File::open(path) {
|
||||||
/// # }
|
Ok(file) => {
|
||||||
/// # }
|
drop(file);
|
||||||
/// #
|
|
||||||
/// impl TestState {
|
Ok(Gemla {
|
||||||
/// fn new(score: f64) -> TestState {
|
data:
|
||||||
/// TestState { score: score }
|
if overwrite {
|
||||||
/// }
|
FileLinked::from_file(path)?
|
||||||
/// }
|
} else {
|
||||||
///
|
FileLinked::new(Bracket {
|
||||||
/// impl genetic_node::GeneticNode for TestState {
|
tree: btree!(None),
|
||||||
/// # fn simulate(&mut self, iterations: u64) -> Result<(), Error> {
|
iteration_scaling: IterationScaling::default()
|
||||||
/// # self.score += iterations as f64;
|
}, path)?
|
||||||
/// # Ok(())
|
}
|
||||||
/// # }
|
})
|
||||||
/// #
|
|
||||||
/// # fn get_fit_score(&self) -> f64 {
|
|
||||||
/// # self.score
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
|
|
||||||
/// # Ok(())
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # fn mutate(&mut self) -> Result<(), Error> {
|
|
||||||
/// # Ok(())
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// fn initialize() -> Result<Box<Self>, Error> {
|
|
||||||
/// Ok(Box::new(TestState { score: 0.0 }))
|
|
||||||
/// }
|
|
||||||
///
|
|
||||||
/// //...
|
|
||||||
/// #
|
|
||||||
/// # fn merge(left: &TestState, right: &TestState) -> Result<Box<Self>, Error> {
|
|
||||||
/// # Ok(Box::new(left.clone()))
|
|
||||||
/// # }
|
|
||||||
/// }
|
|
||||||
///
|
|
||||||
/// # fn main() {
|
|
||||||
/// let mut bracket = Bracket::<TestState>::initialize(path::PathBuf::from("./temp"))
|
|
||||||
/// .expect("Bracket failed to initialize");
|
|
||||||
///
|
|
||||||
/// std::fs::remove_file("./temp").expect("Unable to remove file");
|
|
||||||
/// # }
|
|
||||||
/// ```
|
|
||||||
pub fn initialize(file_path: path::PathBuf) -> Result<FileLinked<Self>, Error> {
|
|
||||||
Ok(FileLinked::new(
|
|
||||||
Bracket {
|
|
||||||
tree: btree!(Some(GeneticNodeWrapper::new()?)),
|
|
||||||
iteration_scaling: IterationScaling::default(),
|
|
||||||
},
|
},
|
||||||
file_path,
|
Err(error) if error.kind() == ErrorKind::NotFound => {
|
||||||
)?)
|
Ok(Gemla {
|
||||||
}
|
data: FileLinked::new(Bracket {
|
||||||
|
tree: btree!(None),
|
||||||
/// Given a bracket object, configures it's [`IterationScaling`].
|
iteration_scaling: IterationScaling::default()
|
||||||
///
|
}, path)?
|
||||||
/// # Examples
|
})
|
||||||
/// ```
|
},
|
||||||
/// # use gemla::bracket::*;
|
Err(error) => Err(Error::IO(error))
|
||||||
/// # use gemla::error::Error;
|
|
||||||
/// # use serde::{Deserialize, Serialize};
|
|
||||||
/// # use std::fmt;
|
|
||||||
/// # use std::str::FromStr;
|
|
||||||
/// # use std::string::ToString;
|
|
||||||
/// # use std::path;
|
|
||||||
/// #
|
|
||||||
/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq, Debug)]
|
|
||||||
/// # struct TestState {
|
|
||||||
/// # pub score: f64,
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # impl fmt::Display for TestState {
|
|
||||||
/// # fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
|
||||||
/// # write!(f, "{}", self.score)
|
|
||||||
/// # }
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # impl TestState {
|
|
||||||
/// # fn new(score: f64) -> TestState {
|
|
||||||
/// # TestState { score: score }
|
|
||||||
/// # }
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # impl genetic_node::GeneticNode for TestState {
|
|
||||||
/// # fn simulate(&mut self, iterations: u64) -> Result<(), Error> {
|
|
||||||
/// # self.score += iterations as f64;
|
|
||||||
/// # Ok(())
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # fn get_fit_score(&self) -> f64 {
|
|
||||||
/// # self.score
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
|
|
||||||
/// # Ok(())
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # fn mutate(&mut self) -> Result<(), Error> {
|
|
||||||
/// # Ok(())
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # fn initialize() -> Result<Box<Self>, Error> {
|
|
||||||
/// # Ok(Box::new(TestState { score: 0.0 }))
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # fn merge(left: &TestState, right: &TestState) -> Result<Box<Self>, Error> {
|
|
||||||
/// # Ok(Box::new(left.clone()))
|
|
||||||
/// # }
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # fn main() {
|
|
||||||
/// let mut bracket = Bracket::<TestState>::initialize(path::PathBuf::from("./temp"))
|
|
||||||
/// .expect("Bracket failed to initialize");
|
|
||||||
///
|
|
||||||
/// // Constant iteration scaling ensures that every node is simulated 5 times.
|
|
||||||
/// bracket
|
|
||||||
/// .mutate(|b| drop(b.iteration_scaling(IterationScaling::Constant(5))))
|
|
||||||
/// .expect("Failed to set iteration scaling");
|
|
||||||
///
|
|
||||||
/// # std::fs::remove_file("./temp").expect("Unable to remove file");
|
|
||||||
/// # }
|
|
||||||
/// ```
|
|
||||||
pub fn iteration_scaling(&mut self, iteration_scaling: IterationScaling) -> &mut Self {
|
|
||||||
self.iteration_scaling = iteration_scaling;
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
// Creates a balanced tree with the given `height` that will be used as a branch of the primary tree.
|
|
||||||
// This additionally simulates and evaluates nodes in the branch as it is built.
|
|
||||||
fn create_new_branch(
|
|
||||||
&self,
|
|
||||||
height: u64,
|
|
||||||
) -> Result<tree::Tree<Option<GeneticNodeWrapper<T>>>, Error> {
|
|
||||||
if height == 1 {
|
|
||||||
let mut base_node = GeneticNodeWrapper::new()?;
|
|
||||||
|
|
||||||
base_node.process_node(match self.iteration_scaling {
|
|
||||||
IterationScaling::Linear(x) => x * height,
|
|
||||||
IterationScaling::Constant(x) => x,
|
|
||||||
})?;
|
|
||||||
|
|
||||||
Ok(btree!(Some(base_node)))
|
|
||||||
} else {
|
|
||||||
let left = self.create_new_branch(height - 1)?;
|
|
||||||
let right = self.create_new_branch(height - 1)?;
|
|
||||||
let mut new_val = if left.val.clone().unwrap().data.unwrap().get_fit_score()
|
|
||||||
>= right.val.clone().unwrap().data.unwrap().get_fit_score()
|
|
||||||
{
|
|
||||||
left.val.clone().unwrap()
|
|
||||||
} else {
|
|
||||||
right.val.clone().unwrap()
|
|
||||||
};
|
|
||||||
|
|
||||||
new_val.process_node(match self.iteration_scaling {
|
|
||||||
IterationScaling::Linear(x) => x * height,
|
|
||||||
IterationScaling::Constant(x) => x,
|
|
||||||
})?;
|
|
||||||
|
|
||||||
Ok(btree!(Some(new_val), left, right))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Runs one step of simulation on the current bracket which includes:
|
pub fn simulate(&mut self, steps: u64) -> Result<(), Error> {
|
||||||
/// 1) Creating a new branch of the same height and performing the same steps for each subtree.
|
self.data.mutate(|b| b.increase_height(steps as usize))??;
|
||||||
/// 2) Simulating the top node of the current branch.
|
|
||||||
/// 3) Comparing the top node of the current branch to the top node of the new branch.
|
|
||||||
/// 4) Takes the best performing node and makes it the root of the tree.
|
|
||||||
///
|
|
||||||
/// # Examples
|
|
||||||
/// ```
|
|
||||||
/// # use gemla::bracket::*;
|
|
||||||
/// # use gemla::error::Error;
|
|
||||||
/// # use serde::{Deserialize, Serialize};
|
|
||||||
/// # use std::fmt;
|
|
||||||
/// # use std::str::FromStr;
|
|
||||||
/// # use std::string::ToString;
|
|
||||||
/// # use std::path;
|
|
||||||
/// #
|
|
||||||
/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq, Debug)]
|
|
||||||
/// # struct TestState {
|
|
||||||
/// # pub score: f64,
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # impl fmt::Display for TestState {
|
|
||||||
/// # fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
|
||||||
/// # write!(f, "{}", self.score)
|
|
||||||
/// # }
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # impl TestState {
|
|
||||||
/// # fn new(score: f64) -> TestState {
|
|
||||||
/// # TestState { score: score }
|
|
||||||
/// # }
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # impl genetic_node::GeneticNode for TestState {
|
|
||||||
/// # fn simulate(&mut self, iterations: u64) -> Result<(), Error> {
|
|
||||||
/// # self.score += iterations as f64;
|
|
||||||
/// # Ok(())
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # fn get_fit_score(&self) -> f64 {
|
|
||||||
/// # self.score
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
|
|
||||||
/// # Ok(())
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # fn mutate(&mut self) -> Result<(), Error> {
|
|
||||||
/// # Ok(())
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # fn initialize() -> Result<Box<Self>, Error> {
|
|
||||||
/// # Ok(Box::new(TestState { score: 0.0 }))
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # fn merge(left: &TestState, right: &TestState) -> Result<Box<Self>, Error> {
|
|
||||||
/// # Ok(Box::new(left.clone()))
|
|
||||||
/// # }
|
|
||||||
/// # }
|
|
||||||
/// #
|
|
||||||
/// # fn main() {
|
|
||||||
/// let mut bracket = Bracket::<TestState>::initialize(path::PathBuf::from("./temp"))
|
|
||||||
/// .expect("Bracket failed to initialize");
|
|
||||||
///
|
|
||||||
/// // Running simulations 3 times
|
|
||||||
/// for _ in 0..3 {
|
|
||||||
/// bracket
|
|
||||||
/// .mutate(|b| drop(b.run_simulation_step()))
|
|
||||||
/// .expect("Failed to run step");
|
|
||||||
/// }
|
|
||||||
///
|
|
||||||
/// assert_eq!(bracket.readonly().tree.height(), 4);
|
|
||||||
///
|
|
||||||
/// # std::fs::remove_file("./temp").expect("Unable to remove file");
|
|
||||||
/// # }
|
|
||||||
/// ```
|
|
||||||
pub fn run_simulation_step(&mut self) -> Result<&mut Self, Error> {
|
|
||||||
let new_branch = self.create_new_branch(self.tree.height())?;
|
|
||||||
|
|
||||||
self.tree
|
self.data.mutate(|b| b.process_tree())??;
|
||||||
.val
|
|
||||||
.clone()
|
|
||||||
.unwrap()
|
|
||||||
.process_node(match self.iteration_scaling {
|
|
||||||
IterationScaling::Linear(x) => (x * self.tree.height()),
|
|
||||||
IterationScaling::Constant(x) => x,
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let new_val = if new_branch
|
Ok(())
|
||||||
.val
|
|
||||||
.clone()
|
|
||||||
.unwrap()
|
|
||||||
.data
|
|
||||||
.unwrap()
|
|
||||||
.get_fit_score()
|
|
||||||
>= self.tree.val.clone().unwrap().data.unwrap().get_fit_score()
|
|
||||||
{
|
|
||||||
new_branch.val.clone()
|
|
||||||
} else {
|
|
||||||
self.tree.val.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
self.tree = btree!(new_val, new_branch, self.tree.clone());
|
|
||||||
|
|
||||||
Ok(self)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use crate::bracket::*;
|
use crate::bracket::*;
|
||||||
use crate::tree::*;
|
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
|
@ -446,10 +140,6 @@ mod tests {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_fit_score(&self) -> f64 {
|
|
||||||
self.score
|
|
||||||
}
|
|
||||||
|
|
||||||
fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
|
fn calculate_scores_and_trim(&mut self) -> Result<(), Error> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -463,66 +153,11 @@ mod tests {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn merge(left: &TestState, right: &TestState) -> Result<Box<TestState>, Error> {
|
fn merge(left: &TestState, right: &TestState) -> Result<Box<TestState>, Error> {
|
||||||
Ok(Box::new(if left.get_fit_score() > right.get_fit_score() {
|
Ok(Box::new(if left.score > right.score {
|
||||||
left.clone()
|
left.clone()
|
||||||
} else {
|
} else {
|
||||||
right.clone()
|
right.clone()
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_new() {
|
|
||||||
let bracket = Bracket::<TestState>::initialize(path::PathBuf::from("./temp"))
|
|
||||||
.expect("Bracket failed to initialize");
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
bracket,
|
|
||||||
file_linked::FileLinked::new(
|
|
||||||
Bracket {
|
|
||||||
tree: Tree {
|
|
||||||
val: Some(GeneticNodeWrapper::new().unwrap()),
|
|
||||||
left: None,
|
|
||||||
right: None
|
|
||||||
},
|
|
||||||
iteration_scaling: IterationScaling::Constant(1)
|
|
||||||
},
|
|
||||||
path::PathBuf::from("./temp")
|
|
||||||
)
|
|
||||||
.unwrap()
|
|
||||||
);
|
|
||||||
|
|
||||||
std::fs::remove_file("./temp").expect("Unable to remove file");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_run() {
|
|
||||||
let mut bracket = Bracket::<TestState>::initialize(path::PathBuf::from("./temp2"))
|
|
||||||
.expect("Bracket failed to initialize");
|
|
||||||
|
|
||||||
bracket
|
|
||||||
.mutate(|b| drop(b.iteration_scaling(IterationScaling::Linear(2))))
|
|
||||||
.expect("Failed to set iteration scaling");
|
|
||||||
for _ in 0..3 {
|
|
||||||
bracket
|
|
||||||
.mutate(|b| drop(b.run_simulation_step()))
|
|
||||||
.expect("Failed to run step");
|
|
||||||
}
|
|
||||||
|
|
||||||
assert_eq!(bracket.readonly().tree.height(), 4);
|
|
||||||
assert_eq!(
|
|
||||||
bracket
|
|
||||||
.readonly()
|
|
||||||
.tree
|
|
||||||
.val
|
|
||||||
.clone()
|
|
||||||
.unwrap()
|
|
||||||
.data
|
|
||||||
.unwrap()
|
|
||||||
.score,
|
|
||||||
15.0
|
|
||||||
);
|
|
||||||
|
|
||||||
std::fs::remove_file("./temp2").expect("Unable to remove file");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,8 @@ pub enum Error {
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
FileLinked(file_linked::Error),
|
FileLinked(file_linked::Error),
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
|
IO(std::io::Error),
|
||||||
|
#[error(transparent)]
|
||||||
Other(#[from] anyhow::Error),
|
Other(#[from] anyhow::Error),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -16,3 +18,9 @@ impl From<file_linked::Error> for Error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<std::io::Error> for Error {
|
||||||
|
fn from(error: std::io::Error) -> Error {
|
||||||
|
Error::IO(error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue