Filling out function to process bracket tree

This commit is contained in:
vandomej 2021-10-06 00:08:05 -07:00
parent 569a17f145
commit e5188ec02f
4 changed files with 122 additions and 48 deletions

View file

@ -2,12 +2,12 @@
extern crate serde;
use std::fs::File;
use std::io::prelude::*;
use std::path::PathBuf;
use anyhow::Context;
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::fs::File;
use std::io::prelude::*;
use std::path::{Path, PathBuf};
use thiserror::Error;
#[derive(Error, Debug)]
@ -106,8 +106,11 @@ where
/// # std::fs::remove_file("./temp").expect("Unable to remove file");
/// # }
/// ```
pub fn new(val: T, path: &PathBuf) -> Result<FileLinked<T>, Error> {
let result = FileLinked { val, path: path.clone() };
pub fn new(val: T, path: &Path) -> Result<FileLinked<T>, Error> {
let result = FileLinked {
val,
path: path.to_path_buf(),
};
result.write_data()?;
Ok(result)
}
@ -279,14 +282,17 @@ where
/// # Ok(())
/// # }
/// ```
pub fn from_file(path: &PathBuf) -> Result<FileLinked<T>, Error> {
let file = File::open(path)
.with_context(|| format!("Unable to open file {}", path.display()))?;
pub fn from_file(path: &Path) -> Result<FileLinked<T>, Error> {
let file =
File::open(path).with_context(|| format!("Unable to open file {}", path.display()))?;
let val = serde_json::from_reader(file)
.with_context(|| String::from("Unable to parse value from file."))?;
Ok(FileLinked { val, path: path.clone() })
Ok(FileLinked {
val,
path: path.to_path_buf(),
})
}
}

View file

@ -100,7 +100,7 @@ pub trait GeneticNode {
/// # Ok(())
/// }
/// }
///
///
/// # impl Node {
/// # fn get_fit_score(&self) -> f64 {
/// # self.models
@ -187,7 +187,7 @@ pub trait GeneticNode {
/// # }
/// # }
/// #
///
///
/// impl GeneticNode for Node {
/// # fn initialize() -> Result<Box<Node>, Error> {
/// # Ok(Box::new(Node {
@ -429,6 +429,18 @@ where
Ok(node)
}
pub fn from(data: T) -> Result<Self, Error> {
let mut node = GeneticNodeWrapper {
data: Some(data),
state: GeneticState::Initialize,
iteration: 0,
};
node.state = GeneticState::Simulate;
Ok(node)
}
/// Performs state transitions on the [`GeneticNode`] wrapped by the [`GeneticNodeWrapper`].
/// Will loop through the node training and scoring process for the given number of `iterations`.
///

View file

@ -4,15 +4,17 @@
pub mod genetic_node;
use crate::error::Error;
use crate::tree;
use genetic_node::{GeneticNodeWrapper, GeneticNode};
use crate::tree::Tree;
use anyhow::anyhow;
use file_linked::FileLinked;
use genetic_node::{GeneticNode, GeneticNodeWrapper};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
use std::path::PathBuf;
use std::fs::File;
use std::io::ErrorKind;
use std::mem::replace;
use std::path::Path;
/// As the bracket tree increases in height, `IterationScaling` can be used to configure the number of iterations that
/// a node runs for.
@ -41,18 +43,71 @@ struct Bracket<T>
where
T: GeneticNode + Serialize,
{
pub tree: tree::Tree<Option<GeneticNodeWrapper<T>>>,
tree: Option<Tree<Option<GeneticNodeWrapper<T>>>>,
iteration_scaling: IterationScaling,
}
impl<T> Bracket<T>
where T: GeneticNode + Serialize
where
T: GeneticNode + Serialize + Debug,
{
fn increase_height(&mut self, _amount: usize) -> Result<(), Error> {
fn build_empty_tree(size: usize) -> Tree<Option<GeneticNodeWrapper<T>>> {
if size <= 1 {
btree!(None)
} else {
btree!(
None,
Bracket::build_empty_tree(size - 1),
Bracket::build_empty_tree(size - 1)
)
}
}
fn increase_height(&mut self, amount: u64) {
for _ in 0..amount {
let height = self.tree.as_ref().unwrap().height();
let tree = replace(&mut self.tree, None);
drop(replace(
&mut self.tree,
Some(btree!(
None,
tree.unwrap(),
Bracket::build_empty_tree(height as usize)
)),
));
}
}
fn process_tree(tree: &mut Tree<Option<GeneticNodeWrapper<T>>>) -> Result<(), Error> {
if tree.val.is_none() {
match (&mut tree.left, &mut tree.right) {
(Some(l), Some(r)) => {
Bracket::process_tree(&mut (*l))?;
Bracket::process_tree(&mut (*r))?;
let left_node = (*l).val.as_ref().unwrap().data.as_ref().unwrap();
let right_node = (*r).val.as_ref().unwrap().data.as_ref().unwrap();
let merged_node = GeneticNode::merge(left_node, right_node)?;
tree.val = Some(GeneticNodeWrapper::from(*merged_node)?);
tree.val.as_mut().unwrap().process_node(1)?;
}
(None, None) => {
tree.val = Some(GeneticNodeWrapper::new()?);
tree.val.as_mut().unwrap().process_node(1)?;
}
_ => {
return Err(Error::Other(anyhow!("unable to process tree {:?}", tree)));
}
}
}
Ok(())
}
fn process_tree(&mut self) -> Result<(), Error> {
fn process(&mut self) -> Result<(), Error> {
Bracket::process_tree(self.tree.as_mut().unwrap())?;
Ok(())
}
}
@ -64,51 +119,52 @@ where T: GeneticNode + Serialize
///
/// [`GeneticNode`]: genetic_node::GeneticNode
pub struct Gemla<T>
where T: GeneticNode + Serialize + DeserializeOwned
where
T: GeneticNode + Serialize + DeserializeOwned,
{
data: FileLinked<Bracket<T>>
data: FileLinked<Bracket<T>>,
}
impl<T> Gemla<T>
impl<T> Gemla<T>
where
T: GeneticNode
+ Serialize
+ DeserializeOwned
+ Default
T: GeneticNode + Serialize + DeserializeOwned + Default + Debug,
{
pub fn new(path: &PathBuf, overwrite: bool) -> Result<Self, Error> {
pub fn new(path: &Path, overwrite: bool) -> Result<Self, Error> {
match File::open(path) {
Ok(file) => {
drop(file);
Ok(Gemla {
data:
if overwrite {
FileLinked::from_file(path)?
} else {
FileLinked::new(Bracket {
tree: btree!(None),
iteration_scaling: IterationScaling::default()
}, path)?
}
data: if overwrite {
FileLinked::from_file(path)?
} else {
FileLinked::new(
Bracket {
tree: Some(btree!(None)),
iteration_scaling: IterationScaling::default(),
},
path,
)?
},
})
},
Err(error) if error.kind() == ErrorKind::NotFound => {
Ok(Gemla {
data: FileLinked::new(Bracket {
tree: btree!(None),
iteration_scaling: IterationScaling::default()
}, path)?
})
},
Err(error) => Err(Error::IO(error))
}
Err(error) if error.kind() == ErrorKind::NotFound => Ok(Gemla {
data: FileLinked::new(
Bracket {
tree: Some(btree!(None)),
iteration_scaling: IterationScaling::default(),
},
path,
)?,
}),
Err(error) => Err(Error::IO(error)),
}
}
pub fn simulate(&mut self, steps: u64) -> Result<(), Error> {
self.data.mutate(|b| b.increase_height(steps as usize))??;
self.data.mutate(|b| b.increase_height(steps))?;
self.data.mutate(|b| b.process_tree())??;
self.data.mutate(|b| b.process())??;
Ok(())
}

View file

@ -131,7 +131,7 @@ impl<T> Tree<T> {
/// btree!("ab"));
/// assert_eq!(t.height(), 3);
/// ```
pub fn height(&self) -> u64 {
pub fn height(&self) -> usize {
match (self.left.as_ref(), self.right.as_ref()) {
(Some(l), Some(r)) => max(l.height(), r.height()) + 1,
(Some(l), None) => l.height() + 1,