Filling out function to process bracket tree
This commit is contained in:
parent
569a17f145
commit
e5188ec02f
4 changed files with 122 additions and 48 deletions
|
@ -2,12 +2,12 @@
|
|||
|
||||
extern crate serde;
|
||||
|
||||
use std::fs::File;
|
||||
use std::io::prelude::*;
|
||||
use std::path::PathBuf;
|
||||
use anyhow::Context;
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::Serialize;
|
||||
use std::fs::File;
|
||||
use std::io::prelude::*;
|
||||
use std::path::{Path, PathBuf};
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
|
@ -106,8 +106,11 @@ where
|
|||
/// # std::fs::remove_file("./temp").expect("Unable to remove file");
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn new(val: T, path: &PathBuf) -> Result<FileLinked<T>, Error> {
|
||||
let result = FileLinked { val, path: path.clone() };
|
||||
pub fn new(val: T, path: &Path) -> Result<FileLinked<T>, Error> {
|
||||
let result = FileLinked {
|
||||
val,
|
||||
path: path.to_path_buf(),
|
||||
};
|
||||
result.write_data()?;
|
||||
Ok(result)
|
||||
}
|
||||
|
@ -279,14 +282,17 @@ where
|
|||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn from_file(path: &PathBuf) -> Result<FileLinked<T>, Error> {
|
||||
let file = File::open(path)
|
||||
.with_context(|| format!("Unable to open file {}", path.display()))?;
|
||||
pub fn from_file(path: &Path) -> Result<FileLinked<T>, Error> {
|
||||
let file =
|
||||
File::open(path).with_context(|| format!("Unable to open file {}", path.display()))?;
|
||||
|
||||
let val = serde_json::from_reader(file)
|
||||
.with_context(|| String::from("Unable to parse value from file."))?;
|
||||
|
||||
Ok(FileLinked { val, path: path.clone() })
|
||||
Ok(FileLinked {
|
||||
val,
|
||||
path: path.to_path_buf(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -100,7 +100,7 @@ pub trait GeneticNode {
|
|||
/// # Ok(())
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
///
|
||||
/// # impl Node {
|
||||
/// # fn get_fit_score(&self) -> f64 {
|
||||
/// # self.models
|
||||
|
@ -187,7 +187,7 @@ pub trait GeneticNode {
|
|||
/// # }
|
||||
/// # }
|
||||
/// #
|
||||
///
|
||||
///
|
||||
/// impl GeneticNode for Node {
|
||||
/// # fn initialize() -> Result<Box<Node>, Error> {
|
||||
/// # Ok(Box::new(Node {
|
||||
|
@ -429,6 +429,18 @@ where
|
|||
Ok(node)
|
||||
}
|
||||
|
||||
pub fn from(data: T) -> Result<Self, Error> {
|
||||
let mut node = GeneticNodeWrapper {
|
||||
data: Some(data),
|
||||
state: GeneticState::Initialize,
|
||||
iteration: 0,
|
||||
};
|
||||
|
||||
node.state = GeneticState::Simulate;
|
||||
|
||||
Ok(node)
|
||||
}
|
||||
|
||||
/// Performs state transitions on the [`GeneticNode`] wrapped by the [`GeneticNodeWrapper`].
|
||||
/// Will loop through the node training and scoring process for the given number of `iterations`.
|
||||
///
|
||||
|
|
|
@ -4,15 +4,17 @@
|
|||
pub mod genetic_node;
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::tree;
|
||||
use genetic_node::{GeneticNodeWrapper, GeneticNode};
|
||||
use crate::tree::Tree;
|
||||
use anyhow::anyhow;
|
||||
use file_linked::FileLinked;
|
||||
use genetic_node::{GeneticNode, GeneticNodeWrapper};
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::Debug;
|
||||
use std::path::PathBuf;
|
||||
use std::fs::File;
|
||||
use std::io::ErrorKind;
|
||||
use std::mem::replace;
|
||||
use std::path::Path;
|
||||
|
||||
/// As the bracket tree increases in height, `IterationScaling` can be used to configure the number of iterations that
|
||||
/// a node runs for.
|
||||
|
@ -41,18 +43,71 @@ struct Bracket<T>
|
|||
where
|
||||
T: GeneticNode + Serialize,
|
||||
{
|
||||
pub tree: tree::Tree<Option<GeneticNodeWrapper<T>>>,
|
||||
tree: Option<Tree<Option<GeneticNodeWrapper<T>>>>,
|
||||
iteration_scaling: IterationScaling,
|
||||
}
|
||||
|
||||
impl<T> Bracket<T>
|
||||
where T: GeneticNode + Serialize
|
||||
where
|
||||
T: GeneticNode + Serialize + Debug,
|
||||
{
|
||||
fn increase_height(&mut self, _amount: usize) -> Result<(), Error> {
|
||||
fn build_empty_tree(size: usize) -> Tree<Option<GeneticNodeWrapper<T>>> {
|
||||
if size <= 1 {
|
||||
btree!(None)
|
||||
} else {
|
||||
btree!(
|
||||
None,
|
||||
Bracket::build_empty_tree(size - 1),
|
||||
Bracket::build_empty_tree(size - 1)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn increase_height(&mut self, amount: u64) {
|
||||
for _ in 0..amount {
|
||||
let height = self.tree.as_ref().unwrap().height();
|
||||
let tree = replace(&mut self.tree, None);
|
||||
drop(replace(
|
||||
&mut self.tree,
|
||||
Some(btree!(
|
||||
None,
|
||||
tree.unwrap(),
|
||||
Bracket::build_empty_tree(height as usize)
|
||||
)),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
fn process_tree(tree: &mut Tree<Option<GeneticNodeWrapper<T>>>) -> Result<(), Error> {
|
||||
if tree.val.is_none() {
|
||||
match (&mut tree.left, &mut tree.right) {
|
||||
(Some(l), Some(r)) => {
|
||||
Bracket::process_tree(&mut (*l))?;
|
||||
Bracket::process_tree(&mut (*r))?;
|
||||
|
||||
let left_node = (*l).val.as_ref().unwrap().data.as_ref().unwrap();
|
||||
let right_node = (*r).val.as_ref().unwrap().data.as_ref().unwrap();
|
||||
let merged_node = GeneticNode::merge(left_node, right_node)?;
|
||||
|
||||
tree.val = Some(GeneticNodeWrapper::from(*merged_node)?);
|
||||
tree.val.as_mut().unwrap().process_node(1)?;
|
||||
}
|
||||
(None, None) => {
|
||||
tree.val = Some(GeneticNodeWrapper::new()?);
|
||||
tree.val.as_mut().unwrap().process_node(1)?;
|
||||
}
|
||||
_ => {
|
||||
return Err(Error::Other(anyhow!("unable to process tree {:?}", tree)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn process_tree(&mut self) -> Result<(), Error> {
|
||||
fn process(&mut self) -> Result<(), Error> {
|
||||
Bracket::process_tree(self.tree.as_mut().unwrap())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -64,51 +119,52 @@ where T: GeneticNode + Serialize
|
|||
///
|
||||
/// [`GeneticNode`]: genetic_node::GeneticNode
|
||||
pub struct Gemla<T>
|
||||
where T: GeneticNode + Serialize + DeserializeOwned
|
||||
where
|
||||
T: GeneticNode + Serialize + DeserializeOwned,
|
||||
{
|
||||
data: FileLinked<Bracket<T>>
|
||||
data: FileLinked<Bracket<T>>,
|
||||
}
|
||||
|
||||
impl<T> Gemla<T>
|
||||
impl<T> Gemla<T>
|
||||
where
|
||||
T: GeneticNode
|
||||
+ Serialize
|
||||
+ DeserializeOwned
|
||||
+ Default
|
||||
T: GeneticNode + Serialize + DeserializeOwned + Default + Debug,
|
||||
{
|
||||
pub fn new(path: &PathBuf, overwrite: bool) -> Result<Self, Error> {
|
||||
pub fn new(path: &Path, overwrite: bool) -> Result<Self, Error> {
|
||||
match File::open(path) {
|
||||
Ok(file) => {
|
||||
drop(file);
|
||||
|
||||
Ok(Gemla {
|
||||
data:
|
||||
if overwrite {
|
||||
FileLinked::from_file(path)?
|
||||
} else {
|
||||
FileLinked::new(Bracket {
|
||||
tree: btree!(None),
|
||||
iteration_scaling: IterationScaling::default()
|
||||
}, path)?
|
||||
}
|
||||
data: if overwrite {
|
||||
FileLinked::from_file(path)?
|
||||
} else {
|
||||
FileLinked::new(
|
||||
Bracket {
|
||||
tree: Some(btree!(None)),
|
||||
iteration_scaling: IterationScaling::default(),
|
||||
},
|
||||
path,
|
||||
)?
|
||||
},
|
||||
})
|
||||
},
|
||||
Err(error) if error.kind() == ErrorKind::NotFound => {
|
||||
Ok(Gemla {
|
||||
data: FileLinked::new(Bracket {
|
||||
tree: btree!(None),
|
||||
iteration_scaling: IterationScaling::default()
|
||||
}, path)?
|
||||
})
|
||||
},
|
||||
Err(error) => Err(Error::IO(error))
|
||||
}
|
||||
Err(error) if error.kind() == ErrorKind::NotFound => Ok(Gemla {
|
||||
data: FileLinked::new(
|
||||
Bracket {
|
||||
tree: Some(btree!(None)),
|
||||
iteration_scaling: IterationScaling::default(),
|
||||
},
|
||||
path,
|
||||
)?,
|
||||
}),
|
||||
Err(error) => Err(Error::IO(error)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn simulate(&mut self, steps: u64) -> Result<(), Error> {
|
||||
self.data.mutate(|b| b.increase_height(steps as usize))??;
|
||||
self.data.mutate(|b| b.increase_height(steps))?;
|
||||
|
||||
self.data.mutate(|b| b.process_tree())??;
|
||||
self.data.mutate(|b| b.process())??;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -131,7 +131,7 @@ impl<T> Tree<T> {
|
|||
/// btree!("ab"));
|
||||
/// assert_eq!(t.height(), 3);
|
||||
/// ```
|
||||
pub fn height(&self) -> u64 {
|
||||
pub fn height(&self) -> usize {
|
||||
match (self.left.as_ref(), self.right.as_ref()) {
|
||||
(Some(l), Some(r)) => max(l.height(), r.height()) + 1,
|
||||
(Some(l), None) => l.height() + 1,
|
||||
|
|
Loading…
Add table
Reference in a new issue