Creating empty state for tree node
This commit is contained in:
parent
daae6c2705
commit
ad4bf7c4ca
2 changed files with 63 additions and 62 deletions
|
@ -11,7 +11,7 @@ use std::fmt;
|
|||
/// An enum used to control the state of a [`GeneticNode`]
|
||||
///
|
||||
/// [`GeneticNode`]: crate::bracket::genetic_node
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Copy)]
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Copy, PartialEq)]
|
||||
#[serde(tag = "enumType", content = "enumContent")]
|
||||
pub enum GeneticState {
|
||||
/// The node and it's data have not finished initializing
|
||||
|
@ -396,14 +396,14 @@ pub trait GeneticNode {
|
|||
|
||||
/// Used externally to wrap a node implementing the [`GeneticNode`] trait. Processes state transitions for the given node as
|
||||
/// well as signal recovery. Transition states are given by [`GeneticState`]
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
|
||||
pub struct GeneticNodeWrapper<T>
|
||||
where
|
||||
T: GeneticNode,
|
||||
{
|
||||
pub data: Option<T>,
|
||||
state: GeneticState,
|
||||
pub iteration: u32,
|
||||
pub iteration: u64,
|
||||
}
|
||||
|
||||
impl<T> GeneticNodeWrapper<T>
|
||||
|
@ -485,7 +485,7 @@ where
|
|||
/// [`simulate`]: crate::bracket::genetic_node::GeneticNode#tymethod.simulate
|
||||
/// [`calculate_scores_and_trim`]: crate::bracket::genetic_node::GeneticNode#tymethod.calculate_scores_and_trim
|
||||
/// [`mutate`]: crate::bracket::genetic_node::GeneticNode#tymethod.mutate
|
||||
pub fn process_node(&mut self, iterations: u32) -> Result<(), Error> {
|
||||
pub fn process_node(&mut self, iterations: u64) -> Result<(), Error> {
|
||||
// Looping through each state transition until the number of iterations have been reached.
|
||||
loop {
|
||||
match (self.state, &self.data) {
|
||||
|
@ -523,6 +523,8 @@ where
|
|||
.unwrap()
|
||||
.mutate()
|
||||
.with_context(|| format!("Error mutating node: {:?}", self))?;
|
||||
|
||||
self.iteration += 1;
|
||||
self.state = GeneticState::Simulate;
|
||||
}
|
||||
(GeneticState::Finish, Some(_)) => {
|
||||
|
|
|
@ -5,10 +5,12 @@ pub mod genetic_node;
|
|||
|
||||
use crate::error::Error;
|
||||
use crate::tree;
|
||||
use genetic_node::GeneticNodeWrapper;
|
||||
|
||||
use file_linked::FileLinked;
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::Debug;
|
||||
use std::path;
|
||||
|
||||
/// As the bracket tree increases in height, `IterationScaling` can be used to configure the number of iterations that
|
||||
|
@ -25,7 +27,7 @@ use std::path;
|
|||
/// # use std::string::ToString;
|
||||
/// # use std::path;
|
||||
/// #
|
||||
/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq)]
|
||||
/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq, Debug)]
|
||||
/// # struct TestState {
|
||||
/// # pub score: f64,
|
||||
/// # }
|
||||
|
@ -98,13 +100,19 @@ pub struct Bracket<T>
|
|||
where
|
||||
T: genetic_node::GeneticNode + Serialize,
|
||||
{
|
||||
pub tree: tree::Tree<T>,
|
||||
pub tree: tree::Tree<Option<GeneticNodeWrapper<T>>>,
|
||||
iteration_scaling: IterationScaling,
|
||||
}
|
||||
|
||||
impl<T> Bracket<T>
|
||||
where
|
||||
T: genetic_node::GeneticNode + Default + DeserializeOwned + Serialize + Clone + PartialEq,
|
||||
T: genetic_node::GeneticNode
|
||||
+ Default
|
||||
+ DeserializeOwned
|
||||
+ Serialize
|
||||
+ Clone
|
||||
+ PartialEq
|
||||
+ Debug,
|
||||
{
|
||||
/// Initializes a bracket of type `T` storing the contents to `file_path`
|
||||
///
|
||||
|
@ -180,7 +188,7 @@ where
|
|||
pub fn initialize(file_path: path::PathBuf) -> Result<FileLinked<Self>, Error> {
|
||||
Ok(FileLinked::new(
|
||||
Bracket {
|
||||
tree: btree!(*T::initialize()?),
|
||||
tree: btree!(Some(GeneticNodeWrapper::new()?)),
|
||||
iteration_scaling: IterationScaling::default(),
|
||||
},
|
||||
file_path,
|
||||
|
@ -199,7 +207,7 @@ where
|
|||
/// # use std::string::ToString;
|
||||
/// # use std::path;
|
||||
/// #
|
||||
/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq)]
|
||||
/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq, Debug)]
|
||||
/// # struct TestState {
|
||||
/// # pub score: f64,
|
||||
/// # }
|
||||
|
@ -258,31 +266,36 @@ where
|
|||
|
||||
// Creates a balanced tree with the given `height` that will be used as a branch of the primary tree.
|
||||
// This additionally simulates and evaluates nodes in the branch as it is built.
|
||||
fn create_new_branch(&self, height: u64) -> Result<tree::Tree<T>, Error> {
|
||||
fn create_new_branch(
|
||||
&self,
|
||||
height: u64,
|
||||
) -> Result<tree::Tree<Option<GeneticNodeWrapper<T>>>, Error> {
|
||||
if height == 1 {
|
||||
let mut base_node = btree!(*T::initialize()?);
|
||||
let mut base_node = GeneticNodeWrapper::new()?;
|
||||
|
||||
base_node.val.simulate(match self.iteration_scaling {
|
||||
base_node.process_node(match self.iteration_scaling {
|
||||
IterationScaling::Linear(x) => x * height,
|
||||
IterationScaling::Constant(x) => x,
|
||||
})?;
|
||||
|
||||
Ok(btree!(base_node.val))
|
||||
Ok(btree!(Some(base_node)))
|
||||
} else {
|
||||
let left = self.create_new_branch(height - 1)?;
|
||||
let right = self.create_new_branch(height - 1)?;
|
||||
let mut new_val = if left.val.get_fit_score() >= right.val.get_fit_score() {
|
||||
left.val.clone()
|
||||
let mut new_val = if left.val.clone().unwrap().data.unwrap().get_fit_score()
|
||||
>= right.val.clone().unwrap().data.unwrap().get_fit_score()
|
||||
{
|
||||
left.val.clone().unwrap()
|
||||
} else {
|
||||
right.val.clone()
|
||||
right.val.clone().unwrap()
|
||||
};
|
||||
|
||||
new_val.simulate(match self.iteration_scaling {
|
||||
new_val.process_node(match self.iteration_scaling {
|
||||
IterationScaling::Linear(x) => x * height,
|
||||
IterationScaling::Constant(x) => x,
|
||||
})?;
|
||||
|
||||
Ok(btree!(new_val, left, right))
|
||||
Ok(btree!(Some(new_val), left, right))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -302,7 +315,7 @@ where
|
|||
/// # use std::string::ToString;
|
||||
/// # use std::path;
|
||||
/// #
|
||||
/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq)]
|
||||
/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq, Debug)]
|
||||
/// # struct TestState {
|
||||
/// # pub score: f64,
|
||||
/// # }
|
||||
|
@ -361,12 +374,24 @@ where
|
|||
pub fn run_simulation_step(&mut self) -> Result<&mut Self, Error> {
|
||||
let new_branch = self.create_new_branch(self.tree.height())?;
|
||||
|
||||
self.tree.val.simulate(match self.iteration_scaling {
|
||||
self.tree
|
||||
.val
|
||||
.clone()
|
||||
.unwrap()
|
||||
.process_node(match self.iteration_scaling {
|
||||
IterationScaling::Linear(x) => (x * self.tree.height()),
|
||||
IterationScaling::Constant(x) => x,
|
||||
})?;
|
||||
|
||||
let new_val = if new_branch.val.get_fit_score() >= self.tree.val.get_fit_score() {
|
||||
let new_val = if new_branch
|
||||
.val
|
||||
.clone()
|
||||
.unwrap()
|
||||
.data
|
||||
.unwrap()
|
||||
.get_fit_score()
|
||||
>= self.tree.val.clone().unwrap().data.unwrap().get_fit_score()
|
||||
{
|
||||
new_branch.val.clone()
|
||||
} else {
|
||||
self.tree.val.clone()
|
||||
|
@ -432,7 +457,7 @@ mod tests {
|
|||
file_linked::FileLinked::new(
|
||||
Bracket {
|
||||
tree: Tree {
|
||||
val: TestState { score: 0.0 },
|
||||
val: Some(GeneticNodeWrapper::new().unwrap()),
|
||||
left: None,
|
||||
right: None
|
||||
},
|
||||
|
@ -460,44 +485,18 @@ mod tests {
|
|||
.expect("Failed to run step");
|
||||
}
|
||||
|
||||
assert_eq!(bracket.readonly().tree.height(), 4);
|
||||
assert_eq!(
|
||||
bracket,
|
||||
file_linked::FileLinked::new(
|
||||
Bracket {
|
||||
iteration_scaling: IterationScaling::Linear(2),
|
||||
tree: btree!(
|
||||
TestState { score: 12.0 },
|
||||
btree!(
|
||||
TestState { score: 12.0 },
|
||||
btree!(
|
||||
TestState { score: 6.0 },
|
||||
btree!(TestState { score: 2.0 }),
|
||||
btree!(TestState { score: 2.0 })
|
||||
),
|
||||
btree!(
|
||||
TestState { score: 6.0 },
|
||||
btree!(TestState { score: 2.0 }),
|
||||
btree!(TestState { score: 2.0 })
|
||||
)
|
||||
),
|
||||
btree!(
|
||||
TestState { score: 12.0 },
|
||||
btree!(
|
||||
TestState { score: 6.0 },
|
||||
btree!(TestState { score: 2.0 }),
|
||||
btree!(TestState { score: 2.0 })
|
||||
),
|
||||
btree!(
|
||||
TestState { score: 6.0 },
|
||||
btree!(TestState { score: 2.0 }),
|
||||
btree!(TestState { score: 2.0 })
|
||||
)
|
||||
)
|
||||
)
|
||||
},
|
||||
path::PathBuf::from("./temp2")
|
||||
)
|
||||
bracket
|
||||
.readonly()
|
||||
.tree
|
||||
.val
|
||||
.clone()
|
||||
.unwrap()
|
||||
.data
|
||||
.unwrap()
|
||||
.score,
|
||||
15.0
|
||||
);
|
||||
|
||||
std::fs::remove_file("./temp2").expect("Unable to remove file");
|
||||
|
|
Loading…
Add table
Reference in a new issue