Creating empty state for tree node
This commit is contained in:
parent
daae6c2705
commit
ad4bf7c4ca
2 changed files with 63 additions and 62 deletions
|
@ -11,7 +11,7 @@ use std::fmt;
|
||||||
/// An enum used to control the state of a [`GeneticNode`]
|
/// An enum used to control the state of a [`GeneticNode`]
|
||||||
///
|
///
|
||||||
/// [`GeneticNode`]: crate::bracket::genetic_node
|
/// [`GeneticNode`]: crate::bracket::genetic_node
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize, Copy)]
|
#[derive(Clone, Debug, Serialize, Deserialize, Copy, PartialEq)]
|
||||||
#[serde(tag = "enumType", content = "enumContent")]
|
#[serde(tag = "enumType", content = "enumContent")]
|
||||||
pub enum GeneticState {
|
pub enum GeneticState {
|
||||||
/// The node and it's data have not finished initializing
|
/// The node and it's data have not finished initializing
|
||||||
|
@ -396,14 +396,14 @@ pub trait GeneticNode {
|
||||||
|
|
||||||
/// Used externally to wrap a node implementing the [`GeneticNode`] trait. Processes state transitions for the given node as
|
/// Used externally to wrap a node implementing the [`GeneticNode`] trait. Processes state transitions for the given node as
|
||||||
/// well as signal recovery. Transition states are given by [`GeneticState`]
|
/// well as signal recovery. Transition states are given by [`GeneticState`]
|
||||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
|
||||||
pub struct GeneticNodeWrapper<T>
|
pub struct GeneticNodeWrapper<T>
|
||||||
where
|
where
|
||||||
T: GeneticNode,
|
T: GeneticNode,
|
||||||
{
|
{
|
||||||
pub data: Option<T>,
|
pub data: Option<T>,
|
||||||
state: GeneticState,
|
state: GeneticState,
|
||||||
pub iteration: u32,
|
pub iteration: u64,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> GeneticNodeWrapper<T>
|
impl<T> GeneticNodeWrapper<T>
|
||||||
|
@ -485,7 +485,7 @@ where
|
||||||
/// [`simulate`]: crate::bracket::genetic_node::GeneticNode#tymethod.simulate
|
/// [`simulate`]: crate::bracket::genetic_node::GeneticNode#tymethod.simulate
|
||||||
/// [`calculate_scores_and_trim`]: crate::bracket::genetic_node::GeneticNode#tymethod.calculate_scores_and_trim
|
/// [`calculate_scores_and_trim`]: crate::bracket::genetic_node::GeneticNode#tymethod.calculate_scores_and_trim
|
||||||
/// [`mutate`]: crate::bracket::genetic_node::GeneticNode#tymethod.mutate
|
/// [`mutate`]: crate::bracket::genetic_node::GeneticNode#tymethod.mutate
|
||||||
pub fn process_node(&mut self, iterations: u32) -> Result<(), Error> {
|
pub fn process_node(&mut self, iterations: u64) -> Result<(), Error> {
|
||||||
// Looping through each state transition until the number of iterations have been reached.
|
// Looping through each state transition until the number of iterations have been reached.
|
||||||
loop {
|
loop {
|
||||||
match (self.state, &self.data) {
|
match (self.state, &self.data) {
|
||||||
|
@ -523,6 +523,8 @@ where
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.mutate()
|
.mutate()
|
||||||
.with_context(|| format!("Error mutating node: {:?}", self))?;
|
.with_context(|| format!("Error mutating node: {:?}", self))?;
|
||||||
|
|
||||||
|
self.iteration += 1;
|
||||||
self.state = GeneticState::Simulate;
|
self.state = GeneticState::Simulate;
|
||||||
}
|
}
|
||||||
(GeneticState::Finish, Some(_)) => {
|
(GeneticState::Finish, Some(_)) => {
|
||||||
|
|
|
@ -5,10 +5,12 @@ pub mod genetic_node;
|
||||||
|
|
||||||
use crate::error::Error;
|
use crate::error::Error;
|
||||||
use crate::tree;
|
use crate::tree;
|
||||||
|
use genetic_node::GeneticNodeWrapper;
|
||||||
|
|
||||||
use file_linked::FileLinked;
|
use file_linked::FileLinked;
|
||||||
use serde::de::DeserializeOwned;
|
use serde::de::DeserializeOwned;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::fmt::Debug;
|
||||||
use std::path;
|
use std::path;
|
||||||
|
|
||||||
/// As the bracket tree increases in height, `IterationScaling` can be used to configure the number of iterations that
|
/// As the bracket tree increases in height, `IterationScaling` can be used to configure the number of iterations that
|
||||||
|
@ -25,7 +27,7 @@ use std::path;
|
||||||
/// # use std::string::ToString;
|
/// # use std::string::ToString;
|
||||||
/// # use std::path;
|
/// # use std::path;
|
||||||
/// #
|
/// #
|
||||||
/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq)]
|
/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq, Debug)]
|
||||||
/// # struct TestState {
|
/// # struct TestState {
|
||||||
/// # pub score: f64,
|
/// # pub score: f64,
|
||||||
/// # }
|
/// # }
|
||||||
|
@ -98,13 +100,19 @@ pub struct Bracket<T>
|
||||||
where
|
where
|
||||||
T: genetic_node::GeneticNode + Serialize,
|
T: genetic_node::GeneticNode + Serialize,
|
||||||
{
|
{
|
||||||
pub tree: tree::Tree<T>,
|
pub tree: tree::Tree<Option<GeneticNodeWrapper<T>>>,
|
||||||
iteration_scaling: IterationScaling,
|
iteration_scaling: IterationScaling,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> Bracket<T>
|
impl<T> Bracket<T>
|
||||||
where
|
where
|
||||||
T: genetic_node::GeneticNode + Default + DeserializeOwned + Serialize + Clone + PartialEq,
|
T: genetic_node::GeneticNode
|
||||||
|
+ Default
|
||||||
|
+ DeserializeOwned
|
||||||
|
+ Serialize
|
||||||
|
+ Clone
|
||||||
|
+ PartialEq
|
||||||
|
+ Debug,
|
||||||
{
|
{
|
||||||
/// Initializes a bracket of type `T` storing the contents to `file_path`
|
/// Initializes a bracket of type `T` storing the contents to `file_path`
|
||||||
///
|
///
|
||||||
|
@ -180,7 +188,7 @@ where
|
||||||
pub fn initialize(file_path: path::PathBuf) -> Result<FileLinked<Self>, Error> {
|
pub fn initialize(file_path: path::PathBuf) -> Result<FileLinked<Self>, Error> {
|
||||||
Ok(FileLinked::new(
|
Ok(FileLinked::new(
|
||||||
Bracket {
|
Bracket {
|
||||||
tree: btree!(*T::initialize()?),
|
tree: btree!(Some(GeneticNodeWrapper::new()?)),
|
||||||
iteration_scaling: IterationScaling::default(),
|
iteration_scaling: IterationScaling::default(),
|
||||||
},
|
},
|
||||||
file_path,
|
file_path,
|
||||||
|
@ -199,7 +207,7 @@ where
|
||||||
/// # use std::string::ToString;
|
/// # use std::string::ToString;
|
||||||
/// # use std::path;
|
/// # use std::path;
|
||||||
/// #
|
/// #
|
||||||
/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq)]
|
/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq, Debug)]
|
||||||
/// # struct TestState {
|
/// # struct TestState {
|
||||||
/// # pub score: f64,
|
/// # pub score: f64,
|
||||||
/// # }
|
/// # }
|
||||||
|
@ -258,31 +266,36 @@ where
|
||||||
|
|
||||||
// Creates a balanced tree with the given `height` that will be used as a branch of the primary tree.
|
// Creates a balanced tree with the given `height` that will be used as a branch of the primary tree.
|
||||||
// This additionally simulates and evaluates nodes in the branch as it is built.
|
// This additionally simulates and evaluates nodes in the branch as it is built.
|
||||||
fn create_new_branch(&self, height: u64) -> Result<tree::Tree<T>, Error> {
|
fn create_new_branch(
|
||||||
|
&self,
|
||||||
|
height: u64,
|
||||||
|
) -> Result<tree::Tree<Option<GeneticNodeWrapper<T>>>, Error> {
|
||||||
if height == 1 {
|
if height == 1 {
|
||||||
let mut base_node = btree!(*T::initialize()?);
|
let mut base_node = GeneticNodeWrapper::new()?;
|
||||||
|
|
||||||
base_node.val.simulate(match self.iteration_scaling {
|
base_node.process_node(match self.iteration_scaling {
|
||||||
IterationScaling::Linear(x) => x * height,
|
IterationScaling::Linear(x) => x * height,
|
||||||
IterationScaling::Constant(x) => x,
|
IterationScaling::Constant(x) => x,
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
Ok(btree!(base_node.val))
|
Ok(btree!(Some(base_node)))
|
||||||
} else {
|
} else {
|
||||||
let left = self.create_new_branch(height - 1)?;
|
let left = self.create_new_branch(height - 1)?;
|
||||||
let right = self.create_new_branch(height - 1)?;
|
let right = self.create_new_branch(height - 1)?;
|
||||||
let mut new_val = if left.val.get_fit_score() >= right.val.get_fit_score() {
|
let mut new_val = if left.val.clone().unwrap().data.unwrap().get_fit_score()
|
||||||
left.val.clone()
|
>= right.val.clone().unwrap().data.unwrap().get_fit_score()
|
||||||
|
{
|
||||||
|
left.val.clone().unwrap()
|
||||||
} else {
|
} else {
|
||||||
right.val.clone()
|
right.val.clone().unwrap()
|
||||||
};
|
};
|
||||||
|
|
||||||
new_val.simulate(match self.iteration_scaling {
|
new_val.process_node(match self.iteration_scaling {
|
||||||
IterationScaling::Linear(x) => x * height,
|
IterationScaling::Linear(x) => x * height,
|
||||||
IterationScaling::Constant(x) => x,
|
IterationScaling::Constant(x) => x,
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
Ok(btree!(new_val, left, right))
|
Ok(btree!(Some(new_val), left, right))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -302,7 +315,7 @@ where
|
||||||
/// # use std::string::ToString;
|
/// # use std::string::ToString;
|
||||||
/// # use std::path;
|
/// # use std::path;
|
||||||
/// #
|
/// #
|
||||||
/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq)]
|
/// # #[derive(Default, Deserialize, Serialize, Clone, PartialEq, Debug)]
|
||||||
/// # struct TestState {
|
/// # struct TestState {
|
||||||
/// # pub score: f64,
|
/// # pub score: f64,
|
||||||
/// # }
|
/// # }
|
||||||
|
@ -361,12 +374,24 @@ where
|
||||||
pub fn run_simulation_step(&mut self) -> Result<&mut Self, Error> {
|
pub fn run_simulation_step(&mut self) -> Result<&mut Self, Error> {
|
||||||
let new_branch = self.create_new_branch(self.tree.height())?;
|
let new_branch = self.create_new_branch(self.tree.height())?;
|
||||||
|
|
||||||
self.tree.val.simulate(match self.iteration_scaling {
|
self.tree
|
||||||
IterationScaling::Linear(x) => (x * self.tree.height()),
|
.val
|
||||||
IterationScaling::Constant(x) => x,
|
.clone()
|
||||||
})?;
|
.unwrap()
|
||||||
|
.process_node(match self.iteration_scaling {
|
||||||
|
IterationScaling::Linear(x) => (x * self.tree.height()),
|
||||||
|
IterationScaling::Constant(x) => x,
|
||||||
|
})?;
|
||||||
|
|
||||||
let new_val = if new_branch.val.get_fit_score() >= self.tree.val.get_fit_score() {
|
let new_val = if new_branch
|
||||||
|
.val
|
||||||
|
.clone()
|
||||||
|
.unwrap()
|
||||||
|
.data
|
||||||
|
.unwrap()
|
||||||
|
.get_fit_score()
|
||||||
|
>= self.tree.val.clone().unwrap().data.unwrap().get_fit_score()
|
||||||
|
{
|
||||||
new_branch.val.clone()
|
new_branch.val.clone()
|
||||||
} else {
|
} else {
|
||||||
self.tree.val.clone()
|
self.tree.val.clone()
|
||||||
|
@ -432,7 +457,7 @@ mod tests {
|
||||||
file_linked::FileLinked::new(
|
file_linked::FileLinked::new(
|
||||||
Bracket {
|
Bracket {
|
||||||
tree: Tree {
|
tree: Tree {
|
||||||
val: TestState { score: 0.0 },
|
val: Some(GeneticNodeWrapper::new().unwrap()),
|
||||||
left: None,
|
left: None,
|
||||||
right: None
|
right: None
|
||||||
},
|
},
|
||||||
|
@ -460,44 +485,18 @@ mod tests {
|
||||||
.expect("Failed to run step");
|
.expect("Failed to run step");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
assert_eq!(bracket.readonly().tree.height(), 4);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
bracket,
|
bracket
|
||||||
file_linked::FileLinked::new(
|
.readonly()
|
||||||
Bracket {
|
.tree
|
||||||
iteration_scaling: IterationScaling::Linear(2),
|
.val
|
||||||
tree: btree!(
|
.clone()
|
||||||
TestState { score: 12.0 },
|
.unwrap()
|
||||||
btree!(
|
.data
|
||||||
TestState { score: 12.0 },
|
.unwrap()
|
||||||
btree!(
|
.score,
|
||||||
TestState { score: 6.0 },
|
15.0
|
||||||
btree!(TestState { score: 2.0 }),
|
|
||||||
btree!(TestState { score: 2.0 })
|
|
||||||
),
|
|
||||||
btree!(
|
|
||||||
TestState { score: 6.0 },
|
|
||||||
btree!(TestState { score: 2.0 }),
|
|
||||||
btree!(TestState { score: 2.0 })
|
|
||||||
)
|
|
||||||
),
|
|
||||||
btree!(
|
|
||||||
TestState { score: 12.0 },
|
|
||||||
btree!(
|
|
||||||
TestState { score: 6.0 },
|
|
||||||
btree!(TestState { score: 2.0 }),
|
|
||||||
btree!(TestState { score: 2.0 })
|
|
||||||
),
|
|
||||||
btree!(
|
|
||||||
TestState { score: 6.0 },
|
|
||||||
btree!(TestState { score: 2.0 }),
|
|
||||||
btree!(TestState { score: 2.0 })
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
},
|
|
||||||
path::PathBuf::from("./temp2")
|
|
||||||
)
|
|
||||||
.unwrap()
|
|
||||||
);
|
);
|
||||||
|
|
||||||
std::fs::remove_file("./temp2").expect("Unable to remove file");
|
std::fs::remove_file("./temp2").expect("Unable to remove file");
|
||||||
|
|
Loading…
Add table
Reference in a new issue