Creating empty state for tree node

This commit is contained in:
vandomej 2021-10-03 01:54:56 -07:00
parent daae6c2705
commit ad4bf7c4ca
2 changed files with 63 additions and 62 deletions

View file

@ -11,7 +11,7 @@ use std::fmt;
/// An enum used to control the state of a [`GeneticNode`] /// An enum used to control the state of a [`GeneticNode`]
/// ///
/// [`GeneticNode`]: crate::bracket::genetic_node /// [`GeneticNode`]: crate::bracket::genetic_node
#[derive(Clone, Debug, Serialize, Deserialize, Copy)] #[derive(Clone, Debug, Serialize, Deserialize, Copy, PartialEq)]
#[serde(tag = "enumType", content = "enumContent")] #[serde(tag = "enumType", content = "enumContent")]
pub enum GeneticState { pub enum GeneticState {
/// The node and it's data have not finished initializing /// The node and it's data have not finished initializing
@ -396,14 +396,14 @@ pub trait GeneticNode {
/// Used externally to wrap a node implementing the [`GeneticNode`] trait. Processes state transitions for the given node as /// Used externally to wrap a node implementing the [`GeneticNode`] trait. Processes state transitions for the given node as
/// well as signal recovery. Transition states are given by [`GeneticState`] /// well as signal recovery. Transition states are given by [`GeneticState`]
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct GeneticNodeWrapper<T> pub struct GeneticNodeWrapper<T>
where where
T: GeneticNode, T: GeneticNode,
{ {
pub data: Option<T>, pub data: Option<T>,
state: GeneticState, state: GeneticState,
pub iteration: u32, pub iteration: u64,
} }
impl<T> GeneticNodeWrapper<T> impl<T> GeneticNodeWrapper<T>
@ -485,7 +485,7 @@ where
/// [`simulate`]: crate::bracket::genetic_node::GeneticNode#tymethod.simulate /// [`simulate`]: crate::bracket::genetic_node::GeneticNode#tymethod.simulate
/// [`calculate_scores_and_trim`]: crate::bracket::genetic_node::GeneticNode#tymethod.calculate_scores_and_trim /// [`calculate_scores_and_trim`]: crate::bracket::genetic_node::GeneticNode#tymethod.calculate_scores_and_trim
/// [`mutate`]: crate::bracket::genetic_node::GeneticNode#tymethod.mutate /// [`mutate`]: crate::bracket::genetic_node::GeneticNode#tymethod.mutate
pub fn process_node(&mut self, iterations: u32) -> Result<(), Error> { pub fn process_node(&mut self, iterations: u64) -> Result<(), Error> {
// Looping through each state transition until the number of iterations have been reached. // Looping through each state transition until the number of iterations have been reached.
loop { loop {
match (self.state, &self.data) { match (self.state, &self.data) {
@ -523,6 +523,8 @@ where
.unwrap() .unwrap()
.mutate() .mutate()
.with_context(|| format!("Error mutating node: {:?}", self))?; .with_context(|| format!("Error mutating node: {:?}", self))?;
self.iteration += 1;
self.state = GeneticState::Simulate; self.state = GeneticState::Simulate;
} }
(GeneticState::Finish, Some(_)) => { (GeneticState::Finish, Some(_)) => {

View file

@ -5,10 +5,12 @@ pub mod genetic_node;
use crate::error::Error; use crate::error::Error;
use crate::tree; use crate::tree;
use genetic_node::GeneticNodeWrapper;
use file_linked::FileLinked; use file_linked::FileLinked;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::fmt::Debug;
use std::path; use std::path;
/// As the bracket tree increases in height, `IterationScaling` can be used to configure the number of iterations that /// As the bracket tree increases in height, `IterationScaling` can be used to configure the number of iterations that
@ -25,7 +27,7 @@ use std::path;
/// # use std::string::ToString; /// # use std::string::ToString;
/// # use std::path; /// # use std::path;
/// # /// #
/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq)] /// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq, Debug)]
/// # struct TestState { /// # struct TestState {
/// # pub score: f64, /// # pub score: f64,
/// # } /// # }
@ -98,13 +100,19 @@ pub struct Bracket<T>
where where
T: genetic_node::GeneticNode + Serialize, T: genetic_node::GeneticNode + Serialize,
{ {
pub tree: tree::Tree<T>, pub tree: tree::Tree<Option<GeneticNodeWrapper<T>>>,
iteration_scaling: IterationScaling, iteration_scaling: IterationScaling,
} }
impl<T> Bracket<T> impl<T> Bracket<T>
where where
T: genetic_node::GeneticNode + Default + DeserializeOwned + Serialize + Clone + PartialEq, T: genetic_node::GeneticNode
+ Default
+ DeserializeOwned
+ Serialize
+ Clone
+ PartialEq
+ Debug,
{ {
/// Initializes a bracket of type `T` storing the contents to `file_path` /// Initializes a bracket of type `T` storing the contents to `file_path`
/// ///
@ -180,7 +188,7 @@ where
pub fn initialize(file_path: path::PathBuf) -> Result<FileLinked<Self>, Error> { pub fn initialize(file_path: path::PathBuf) -> Result<FileLinked<Self>, Error> {
Ok(FileLinked::new( Ok(FileLinked::new(
Bracket { Bracket {
tree: btree!(*T::initialize()?), tree: btree!(Some(GeneticNodeWrapper::new()?)),
iteration_scaling: IterationScaling::default(), iteration_scaling: IterationScaling::default(),
}, },
file_path, file_path,
@ -199,7 +207,7 @@ where
/// # use std::string::ToString; /// # use std::string::ToString;
/// # use std::path; /// # use std::path;
/// # /// #
/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq)] /// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq, Debug)]
/// # struct TestState { /// # struct TestState {
/// # pub score: f64, /// # pub score: f64,
/// # } /// # }
@ -258,31 +266,36 @@ where
// Creates a balanced tree with the given `height` that will be used as a branch of the primary tree. // Creates a balanced tree with the given `height` that will be used as a branch of the primary tree.
// This additionally simulates and evaluates nodes in the branch as it is built. // This additionally simulates and evaluates nodes in the branch as it is built.
fn create_new_branch(&self, height: u64) -> Result<tree::Tree<T>, Error> { fn create_new_branch(
&self,
height: u64,
) -> Result<tree::Tree<Option<GeneticNodeWrapper<T>>>, Error> {
if height == 1 { if height == 1 {
let mut base_node = btree!(*T::initialize()?); let mut base_node = GeneticNodeWrapper::new()?;
base_node.val.simulate(match self.iteration_scaling { base_node.process_node(match self.iteration_scaling {
IterationScaling::Linear(x) => x * height, IterationScaling::Linear(x) => x * height,
IterationScaling::Constant(x) => x, IterationScaling::Constant(x) => x,
})?; })?;
Ok(btree!(base_node.val)) Ok(btree!(Some(base_node)))
} else { } else {
let left = self.create_new_branch(height - 1)?; let left = self.create_new_branch(height - 1)?;
let right = self.create_new_branch(height - 1)?; let right = self.create_new_branch(height - 1)?;
let mut new_val = if left.val.get_fit_score() >= right.val.get_fit_score() { let mut new_val = if left.val.clone().unwrap().data.unwrap().get_fit_score()
left.val.clone() >= right.val.clone().unwrap().data.unwrap().get_fit_score()
{
left.val.clone().unwrap()
} else { } else {
right.val.clone() right.val.clone().unwrap()
}; };
new_val.simulate(match self.iteration_scaling { new_val.process_node(match self.iteration_scaling {
IterationScaling::Linear(x) => x * height, IterationScaling::Linear(x) => x * height,
IterationScaling::Constant(x) => x, IterationScaling::Constant(x) => x,
})?; })?;
Ok(btree!(new_val, left, right)) Ok(btree!(Some(new_val), left, right))
} }
} }
@ -302,7 +315,7 @@ where
/// # use std::string::ToString; /// # use std::string::ToString;
/// # use std::path; /// # use std::path;
/// # /// #
/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq)] /// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq, Debug)]
/// # struct TestState { /// # struct TestState {
/// # pub score: f64, /// # pub score: f64,
/// # } /// # }
@ -361,12 +374,24 @@ where
pub fn run_simulation_step(&mut self) -> Result<&mut Self, Error> { pub fn run_simulation_step(&mut self) -> Result<&mut Self, Error> {
let new_branch = self.create_new_branch(self.tree.height())?; let new_branch = self.create_new_branch(self.tree.height())?;
self.tree.val.simulate(match self.iteration_scaling { self.tree
.val
.clone()
.unwrap()
.process_node(match self.iteration_scaling {
IterationScaling::Linear(x) => (x * self.tree.height()), IterationScaling::Linear(x) => (x * self.tree.height()),
IterationScaling::Constant(x) => x, IterationScaling::Constant(x) => x,
})?; })?;
let new_val = if new_branch.val.get_fit_score() >= self.tree.val.get_fit_score() { let new_val = if new_branch
.val
.clone()
.unwrap()
.data
.unwrap()
.get_fit_score()
>= self.tree.val.clone().unwrap().data.unwrap().get_fit_score()
{
new_branch.val.clone() new_branch.val.clone()
} else { } else {
self.tree.val.clone() self.tree.val.clone()
@ -432,7 +457,7 @@ mod tests {
file_linked::FileLinked::new( file_linked::FileLinked::new(
Bracket { Bracket {
tree: Tree { tree: Tree {
val: TestState { score: 0.0 }, val: Some(GeneticNodeWrapper::new().unwrap()),
left: None, left: None,
right: None right: None
}, },
@ -460,44 +485,18 @@ mod tests {
.expect("Failed to run step"); .expect("Failed to run step");
} }
assert_eq!(bracket.readonly().tree.height(), 4);
assert_eq!( assert_eq!(
bracket, bracket
file_linked::FileLinked::new( .readonly()
Bracket { .tree
iteration_scaling: IterationScaling::Linear(2), .val
tree: btree!( .clone()
TestState { score: 12.0 },
btree!(
TestState { score: 12.0 },
btree!(
TestState { score: 6.0 },
btree!(TestState { score: 2.0 }),
btree!(TestState { score: 2.0 })
),
btree!(
TestState { score: 6.0 },
btree!(TestState { score: 2.0 }),
btree!(TestState { score: 2.0 })
)
),
btree!(
TestState { score: 12.0 },
btree!(
TestState { score: 6.0 },
btree!(TestState { score: 2.0 }),
btree!(TestState { score: 2.0 })
),
btree!(
TestState { score: 6.0 },
btree!(TestState { score: 2.0 }),
btree!(TestState { score: 2.0 })
)
)
)
},
path::PathBuf::from("./temp2")
)
.unwrap() .unwrap()
.data
.unwrap()
.score,
15.0
); );
std::fs::remove_file("./temp2").expect("Unable to remove file"); std::fs::remove_file("./temp2").expect("Unable to remove file");