Creating empty state for tree node

This commit is contained in:
vandomej 2021-10-03 01:54:56 -07:00
parent daae6c2705
commit ad4bf7c4ca
2 changed files with 63 additions and 62 deletions

View file

@ -11,7 +11,7 @@ use std::fmt;
/// An enum used to control the state of a [`GeneticNode`]
///
/// [`GeneticNode`]: crate::bracket::genetic_node
#[derive(Clone, Debug, Serialize, Deserialize, Copy)]
#[derive(Clone, Debug, Serialize, Deserialize, Copy, PartialEq)]
#[serde(tag = "enumType", content = "enumContent")]
pub enum GeneticState {
/// The node and it's data have not finished initializing
@ -396,14 +396,14 @@ pub trait GeneticNode {
/// Used externally to wrap a node implementing the [`GeneticNode`] trait. Processes state transitions for the given node as
/// well as signal recovery. Transition states are given by [`GeneticState`]
#[derive(Serialize, Deserialize, Clone, Debug)]
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct GeneticNodeWrapper<T>
where
T: GeneticNode,
{
pub data: Option<T>,
state: GeneticState,
pub iteration: u32,
pub iteration: u64,
}
impl<T> GeneticNodeWrapper<T>
@ -485,7 +485,7 @@ where
/// [`simulate`]: crate::bracket::genetic_node::GeneticNode#tymethod.simulate
/// [`calculate_scores_and_trim`]: crate::bracket::genetic_node::GeneticNode#tymethod.calculate_scores_and_trim
/// [`mutate`]: crate::bracket::genetic_node::GeneticNode#tymethod.mutate
pub fn process_node(&mut self, iterations: u32) -> Result<(), Error> {
pub fn process_node(&mut self, iterations: u64) -> Result<(), Error> {
// Looping through each state transition until the number of iterations have been reached.
loop {
match (self.state, &self.data) {
@ -523,6 +523,8 @@ where
.unwrap()
.mutate()
.with_context(|| format!("Error mutating node: {:?}", self))?;
self.iteration += 1;
self.state = GeneticState::Simulate;
}
(GeneticState::Finish, Some(_)) => {

View file

@ -5,10 +5,12 @@ pub mod genetic_node;
use crate::error::Error;
use crate::tree;
use genetic_node::GeneticNodeWrapper;
use file_linked::FileLinked;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
use std::path;
/// As the bracket tree increases in height, `IterationScaling` can be used to configure the number of iterations that
@ -25,7 +27,7 @@ use std::path;
/// # use std::string::ToString;
/// # use std::path;
/// #
/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq)]
/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq, Debug)]
/// # struct TestState {
/// # pub score: f64,
/// # }
@ -98,13 +100,19 @@ pub struct Bracket<T>
where
T: genetic_node::GeneticNode + Serialize,
{
pub tree: tree::Tree<T>,
pub tree: tree::Tree<Option<GeneticNodeWrapper<T>>>,
iteration_scaling: IterationScaling,
}
impl<T> Bracket<T>
where
T: genetic_node::GeneticNode + Default + DeserializeOwned + Serialize + Clone + PartialEq,
T: genetic_node::GeneticNode
+ Default
+ DeserializeOwned
+ Serialize
+ Clone
+ PartialEq
+ Debug,
{
/// Initializes a bracket of type `T` storing the contents to `file_path`
///
@ -180,7 +188,7 @@ where
pub fn initialize(file_path: path::PathBuf) -> Result<FileLinked<Self>, Error> {
Ok(FileLinked::new(
Bracket {
tree: btree!(*T::initialize()?),
tree: btree!(Some(GeneticNodeWrapper::new()?)),
iteration_scaling: IterationScaling::default(),
},
file_path,
@ -199,7 +207,7 @@ where
/// # use std::string::ToString;
/// # use std::path;
/// #
/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq)]
/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq, Debug)]
/// # struct TestState {
/// # pub score: f64,
/// # }
@ -258,31 +266,36 @@ where
// Creates a balanced tree with the given `height` that will be used as a branch of the primary tree.
// This additionally simulates and evaluates nodes in the branch as it is built.
fn create_new_branch(&self, height: u64) -> Result<tree::Tree<T>, Error> {
fn create_new_branch(
&self,
height: u64,
) -> Result<tree::Tree<Option<GeneticNodeWrapper<T>>>, Error> {
if height == 1 {
let mut base_node = btree!(*T::initialize()?);
let mut base_node = GeneticNodeWrapper::new()?;
base_node.val.simulate(match self.iteration_scaling {
base_node.process_node(match self.iteration_scaling {
IterationScaling::Linear(x) => x * height,
IterationScaling::Constant(x) => x,
})?;
Ok(btree!(base_node.val))
Ok(btree!(Some(base_node)))
} else {
let left = self.create_new_branch(height - 1)?;
let right = self.create_new_branch(height - 1)?;
let mut new_val = if left.val.get_fit_score() >= right.val.get_fit_score() {
left.val.clone()
let mut new_val = if left.val.clone().unwrap().data.unwrap().get_fit_score()
>= right.val.clone().unwrap().data.unwrap().get_fit_score()
{
left.val.clone().unwrap()
} else {
right.val.clone()
right.val.clone().unwrap()
};
new_val.simulate(match self.iteration_scaling {
new_val.process_node(match self.iteration_scaling {
IterationScaling::Linear(x) => x * height,
IterationScaling::Constant(x) => x,
})?;
Ok(btree!(new_val, left, right))
Ok(btree!(Some(new_val), left, right))
}
}
@ -302,7 +315,7 @@ where
/// # use std::string::ToString;
/// # use std::path;
/// #
/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq)]
/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq, Debug)]
/// # struct TestState {
/// # pub score: f64,
/// # }
@ -361,12 +374,24 @@ where
pub fn run_simulation_step(&mut self) -> Result<&mut Self, Error> {
let new_branch = self.create_new_branch(self.tree.height())?;
self.tree.val.simulate(match self.iteration_scaling {
IterationScaling::Linear(x) => (x * self.tree.height()),
IterationScaling::Constant(x) => x,
})?;
self.tree
.val
.clone()
.unwrap()
.process_node(match self.iteration_scaling {
IterationScaling::Linear(x) => (x * self.tree.height()),
IterationScaling::Constant(x) => x,
})?;
let new_val = if new_branch.val.get_fit_score() >= self.tree.val.get_fit_score() {
let new_val = if new_branch
.val
.clone()
.unwrap()
.data
.unwrap()
.get_fit_score()
>= self.tree.val.clone().unwrap().data.unwrap().get_fit_score()
{
new_branch.val.clone()
} else {
self.tree.val.clone()
@ -432,7 +457,7 @@ mod tests {
file_linked::FileLinked::new(
Bracket {
tree: Tree {
val: TestState { score: 0.0 },
val: Some(GeneticNodeWrapper::new().unwrap()),
left: None,
right: None
},
@ -460,44 +485,18 @@ mod tests {
.expect("Failed to run step");
}
assert_eq!(bracket.readonly().tree.height(), 4);
assert_eq!(
bracket,
file_linked::FileLinked::new(
Bracket {
iteration_scaling: IterationScaling::Linear(2),
tree: btree!(
TestState { score: 12.0 },
btree!(
TestState { score: 12.0 },
btree!(
TestState { score: 6.0 },
btree!(TestState { score: 2.0 }),
btree!(TestState { score: 2.0 })
),
btree!(
TestState { score: 6.0 },
btree!(TestState { score: 2.0 }),
btree!(TestState { score: 2.0 })
)
),
btree!(
TestState { score: 12.0 },
btree!(
TestState { score: 6.0 },
btree!(TestState { score: 2.0 }),
btree!(TestState { score: 2.0 })
),
btree!(
TestState { score: 6.0 },
btree!(TestState { score: 2.0 }),
btree!(TestState { score: 2.0 })
)
)
)
},
path::PathBuf::from("./temp2")
)
.unwrap()
bracket
.readonly()
.tree
.val
.clone()
.unwrap()
.data
.unwrap()
.score,
15.0
);
std::fs::remove_file("./temp2").expect("Unable to remove file");