From e5188ec02ff62c82135abd74d742df12aa3d00e7 Mon Sep 17 00:00:00 2001 From: vandomej Date: Wed, 6 Oct 2021 00:08:05 -0700 Subject: [PATCH] Filling out function to process bracket tree --- file_linked/src/lib.rs | 24 +++--- gemla/src/bracket/genetic_node.rs | 16 +++- gemla/src/bracket/mod.rs | 128 +++++++++++++++++++++--------- gemla/src/tree/mod.rs | 2 +- 4 files changed, 122 insertions(+), 48 deletions(-) diff --git a/file_linked/src/lib.rs b/file_linked/src/lib.rs index ff33531..7d75e0f 100644 --- a/file_linked/src/lib.rs +++ b/file_linked/src/lib.rs @@ -2,12 +2,12 @@ extern crate serde; -use std::fs::File; -use std::io::prelude::*; -use std::path::PathBuf; use anyhow::Context; use serde::de::DeserializeOwned; use serde::Serialize; +use std::fs::File; +use std::io::prelude::*; +use std::path::{Path, PathBuf}; use thiserror::Error; #[derive(Error, Debug)] @@ -106,8 +106,11 @@ where /// # std::fs::remove_file("./temp").expect("Unable to remove file"); /// # } /// ``` - pub fn new(val: T, path: &PathBuf) -> Result, Error> { - let result = FileLinked { val, path: path.clone() }; + pub fn new(val: T, path: &Path) -> Result, Error> { + let result = FileLinked { + val, + path: path.to_path_buf(), + }; result.write_data()?; Ok(result) } @@ -279,14 +282,17 @@ where /// # Ok(()) /// # } /// ``` - pub fn from_file(path: &PathBuf) -> Result, Error> { - let file = File::open(path) - .with_context(|| format!("Unable to open file {}", path.display()))?; + pub fn from_file(path: &Path) -> Result, Error> { + let file = + File::open(path).with_context(|| format!("Unable to open file {}", path.display()))?; let val = serde_json::from_reader(file) .with_context(|| String::from("Unable to parse value from file."))?; - Ok(FileLinked { val, path: path.clone() }) + Ok(FileLinked { + val, + path: path.to_path_buf(), + }) } } diff --git a/gemla/src/bracket/genetic_node.rs b/gemla/src/bracket/genetic_node.rs index b6d5e20..e2ce60e 100644 --- a/gemla/src/bracket/genetic_node.rs +++ b/gemla/src/bracket/genetic_node.rs @@ -100,7 +100,7 @@ pub trait GeneticNode { /// # Ok(()) /// } /// } - /// + /// /// # impl Node { /// # fn get_fit_score(&self) -> f64 { /// # self.models @@ -187,7 +187,7 @@ pub trait GeneticNode { /// # } /// # } /// # - /// + /// /// impl GeneticNode for Node { /// # fn initialize() -> Result, Error> { /// # Ok(Box::new(Node { @@ -429,6 +429,18 @@ where Ok(node) } + pub fn from(data: T) -> Result { + let mut node = GeneticNodeWrapper { + data: Some(data), + state: GeneticState::Initialize, + iteration: 0, + }; + + node.state = GeneticState::Simulate; + + Ok(node) + } + /// Performs state transitions on the [`GeneticNode`] wrapped by the [`GeneticNodeWrapper`]. /// Will loop through the node training and scoring process for the given number of `iterations`. /// diff --git a/gemla/src/bracket/mod.rs b/gemla/src/bracket/mod.rs index 1774246..038b7ca 100644 --- a/gemla/src/bracket/mod.rs +++ b/gemla/src/bracket/mod.rs @@ -4,15 +4,17 @@ pub mod genetic_node; use crate::error::Error; -use crate::tree; -use genetic_node::{GeneticNodeWrapper, GeneticNode}; +use crate::tree::Tree; +use anyhow::anyhow; use file_linked::FileLinked; +use genetic_node::{GeneticNode, GeneticNodeWrapper}; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use std::fmt::Debug; -use std::path::PathBuf; use std::fs::File; use std::io::ErrorKind; +use std::mem::replace; +use std::path::Path; /// As the bracket tree increases in height, `IterationScaling` can be used to configure the number of iterations that /// a node runs for. @@ -41,18 +43,71 @@ struct Bracket where T: GeneticNode + Serialize, { - pub tree: tree::Tree>>, + tree: Option>>>, iteration_scaling: IterationScaling, } impl Bracket -where T: GeneticNode + Serialize +where + T: GeneticNode + Serialize + Debug, { - fn increase_height(&mut self, _amount: usize) -> Result<(), Error> { + fn build_empty_tree(size: usize) -> Tree>> { + if size <= 1 { + btree!(None) + } else { + btree!( + None, + Bracket::build_empty_tree(size - 1), + Bracket::build_empty_tree(size - 1) + ) + } + } + + fn increase_height(&mut self, amount: u64) { + for _ in 0..amount { + let height = self.tree.as_ref().unwrap().height(); + let tree = replace(&mut self.tree, None); + drop(replace( + &mut self.tree, + Some(btree!( + None, + tree.unwrap(), + Bracket::build_empty_tree(height as usize) + )), + )); + } + } + + fn process_tree(tree: &mut Tree>>) -> Result<(), Error> { + if tree.val.is_none() { + match (&mut tree.left, &mut tree.right) { + (Some(l), Some(r)) => { + Bracket::process_tree(&mut (*l))?; + Bracket::process_tree(&mut (*r))?; + + let left_node = (*l).val.as_ref().unwrap().data.as_ref().unwrap(); + let right_node = (*r).val.as_ref().unwrap().data.as_ref().unwrap(); + let merged_node = GeneticNode::merge(left_node, right_node)?; + + tree.val = Some(GeneticNodeWrapper::from(*merged_node)?); + tree.val.as_mut().unwrap().process_node(1)?; + } + (None, None) => { + tree.val = Some(GeneticNodeWrapper::new()?); + tree.val.as_mut().unwrap().process_node(1)?; + } + _ => { + return Err(Error::Other(anyhow!("unable to process tree {:?}", tree))); + } + } + } + Ok(()) } - fn process_tree(&mut self) -> Result<(), Error> { + fn process(&mut self) -> Result<(), Error> { + Bracket::process_tree(self.tree.as_mut().unwrap())?; + Ok(()) } } @@ -64,51 +119,52 @@ where T: GeneticNode + Serialize /// /// [`GeneticNode`]: genetic_node::GeneticNode pub struct Gemla -where T: GeneticNode + Serialize + DeserializeOwned +where + T: GeneticNode + Serialize + DeserializeOwned, { - data: FileLinked> + data: FileLinked>, } -impl Gemla +impl Gemla where - T: GeneticNode - + Serialize - + DeserializeOwned - + Default + T: GeneticNode + Serialize + DeserializeOwned + Default + Debug, { - pub fn new(path: &PathBuf, overwrite: bool) -> Result { + pub fn new(path: &Path, overwrite: bool) -> Result { match File::open(path) { Ok(file) => { drop(file); Ok(Gemla { - data: - if overwrite { - FileLinked::from_file(path)? - } else { - FileLinked::new(Bracket { - tree: btree!(None), - iteration_scaling: IterationScaling::default() - }, path)? - } + data: if overwrite { + FileLinked::from_file(path)? + } else { + FileLinked::new( + Bracket { + tree: Some(btree!(None)), + iteration_scaling: IterationScaling::default(), + }, + path, + )? + }, }) - }, - Err(error) if error.kind() == ErrorKind::NotFound => { - Ok(Gemla { - data: FileLinked::new(Bracket { - tree: btree!(None), - iteration_scaling: IterationScaling::default() - }, path)? - }) - }, - Err(error) => Err(Error::IO(error)) + } + Err(error) if error.kind() == ErrorKind::NotFound => Ok(Gemla { + data: FileLinked::new( + Bracket { + tree: Some(btree!(None)), + iteration_scaling: IterationScaling::default(), + }, + path, + )?, + }), + Err(error) => Err(Error::IO(error)), } } pub fn simulate(&mut self, steps: u64) -> Result<(), Error> { - self.data.mutate(|b| b.increase_height(steps as usize))??; + self.data.mutate(|b| b.increase_height(steps))?; - self.data.mutate(|b| b.process_tree())??; + self.data.mutate(|b| b.process())??; Ok(()) } diff --git a/gemla/src/tree/mod.rs b/gemla/src/tree/mod.rs index 8c39d42..c1a2b39 100644 --- a/gemla/src/tree/mod.rs +++ b/gemla/src/tree/mod.rs @@ -131,7 +131,7 @@ impl Tree { /// btree!("ab")); /// assert_eq!(t.height(), 3); /// ``` - pub fn height(&self) -> u64 { + pub fn height(&self) -> usize { match (self.left.as_ref(), self.right.as_ref()) { (Some(l), Some(r)) => max(l.height(), r.height()) + 1, (Some(l), None) => l.height() + 1,