diff --git a/gemla/src/bin/test_state/mod.rs b/gemla/src/bin/test_state/mod.rs index 4b9d994..329a172 100644 --- a/gemla/src/bin/test_state/mod.rs +++ b/gemla/src/bin/test_state/mod.rs @@ -1,19 +1,20 @@ -use rand::prelude::*; -use rand::rngs::ThreadRng; use gemla::bracket::genetic_node::GeneticNode; use gemla::error; +use rand::prelude::*; +use rand::rngs::ThreadRng; use std::convert::TryInto; const POPULATION_SIZE: u64 = 5; +const POPULATION_REDUCTION_SIZE: u64 = 3; struct TestState { pub population: Vec, - thread_rng: ThreadRng + thread_rng: ThreadRng, } impl GeneticNode for TestState { fn initialize() -> Result, error::Error> { - let mut thread_rng = rand::thread_rng(); + let mut thread_rng = thread_rng(); let mut population: Vec = vec![]; for _ in 0..POPULATION_SIZE { @@ -22,28 +23,38 @@ impl GeneticNode for TestState { Ok(Box::new(TestState { population, - thread_rng + thread_rng, })) } fn simulate(&mut self, iterations: u64) -> Result<(), error::Error> { for _ in 0..iterations { - self.population = self.population.clone().iter().map(|p| p + self.thread_rng.gen_range(-10.0..10.0)).collect() + self.population = self + .population + .clone() + .iter() + .map(|p| p + self.thread_rng.gen_range(-10.0..10.0)) + .collect() } Ok(()) } fn get_fit_score(&self) -> f64 { - self.population.clone().into_iter().reduce(f64::max).unwrap() + self.population + .clone() + .into_iter() + .reduce(f64::max) + .unwrap() } fn calculate_scores_and_trim(&mut self) -> Result<(), error::Error> { let mut v = self.population.clone(); v.sort_by(|a, b| a.partial_cmp(b).unwrap()); + v.reverse(); - self.population = v[4..].to_vec(); + self.population = v[0..(POPULATION_REDUCTION_SIZE as usize)].to_vec(); Ok(()) } @@ -56,7 +67,7 @@ impl GeneticNode for TestState { let new_individual_index = self.thread_rng.gen_range(0..self.population.len()); let mut cross_breed_index = self.thread_rng.gen_range(0..self.population.len()); - + loop { if new_individual_index != cross_breed_index { break; @@ -67,7 +78,7 @@ impl GeneticNode for TestState { let mut new_individual = self.population.clone()[new_individual_index]; let cross_breed = self.population.clone()[cross_breed_index]; - + new_individual += cross_breed + self.thread_rng.gen_range(-10.0..10.0); self.population.push(new_individual); @@ -75,4 +86,23 @@ impl GeneticNode for TestState { Ok(()) } -} \ No newline at end of file + + fn merge(left: &TestState, right: &TestState) -> Result, error::Error> { + let mut v = left.population.clone(); + v.append(&mut right.population.clone()); + + v.sort_by(|a, b| a.partial_cmp(b).unwrap()); + v.reverse(); + + v = v[..(POPULATION_REDUCTION_SIZE as usize)].to_vec(); + + let mut result = TestState { + population: v, + thread_rng: thread_rng(), + }; + + result.mutate()?; + + Ok(Box::new(result)) + } +} diff --git a/gemla/src/bracket/genetic_node.rs b/gemla/src/bracket/genetic_node.rs index 3c9e4b4..9875d14 100644 --- a/gemla/src/bracket/genetic_node.rs +++ b/gemla/src/bracket/genetic_node.rs @@ -43,7 +43,7 @@ pub trait GeneticNode { /// } /// /// impl GeneticNode for Node { - /// fn initialize() -> Result, Error> { + /// fn initialize() -> Result, Error> { /// Ok(Box::new(Node {fit_score: 0.0})) /// } /// @@ -64,6 +64,10 @@ pub trait GeneticNode { /// # fn mutate(&mut self) -> Result<(), Error> { /// # Ok(()) /// # } + /// # + /// # fn merge(left: &Node, right: &Node) -> Result, Error> { + /// # Ok(Box::new(Node {fit_score: 0.0})) + /// # } /// } /// /// # fn main() -> Result<(), Error> { @@ -102,7 +106,7 @@ pub trait GeneticNode { /// } /// /// impl GeneticNode for Node { - /// # fn initialize() -> Result, Error> { + /// # fn initialize() -> Result, Error> { /// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]})) /// # } /// # @@ -129,6 +133,10 @@ pub trait GeneticNode { /// # fn mutate(&mut self) -> Result<(), Error> { /// # Ok(()) /// # } + /// # + /// # fn merge(left: &Node, right: &Node) -> Result, Error> { + /// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]})) + /// # } /// } /// /// # fn main() -> Result<(), Error> { @@ -167,7 +175,7 @@ pub trait GeneticNode { /// # } /// /// impl GeneticNode for Node { - /// # fn initialize() -> Result, Error> { + /// # fn initialize() -> Result, Error> { /// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]})) /// # } /// # @@ -195,6 +203,10 @@ pub trait GeneticNode { /// # fn mutate(&mut self) -> Result<(), Error> { /// # Ok(()) /// # } + /// # + /// # fn merge(left: &Node, right: &Node) -> Result, Error> { + /// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]})) + /// # } /// } /// /// # fn main() -> Result<(), Error> { @@ -233,7 +245,7 @@ pub trait GeneticNode { /// # } /// /// impl GeneticNode for Node { - /// # fn initialize() -> Result, Error> { + /// # fn initialize() -> Result, Error> { /// # Ok(Box::new(Node { /// # models: vec![ /// # Model { fit_score: 0.0 }, @@ -276,6 +288,10 @@ pub trait GeneticNode { /// # fn mutate(&mut self) -> Result<(), Error> { /// # Ok(()) /// # } + /// # + /// # fn merge(left: &Node, right: &Node) -> Result, Error> { + /// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}], population_size: 1})) + /// # } /// } /// /// # fn main() -> Result<(), Error> { @@ -328,7 +344,7 @@ pub trait GeneticNode { /// } /// /// impl GeneticNode for Node { - /// # fn initialize() -> Result, Error> { + /// # fn initialize() -> Result, Error> { /// # Ok(Box::new(Node { /// # models: vec![ /// # Model { fit_score: 0.0 }, @@ -374,6 +390,10 @@ pub trait GeneticNode { /// } /// } /// } + /// # + /// # fn merge(left: &Node, right: &Node) -> Result, Error> { + /// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}], population_size: 1})) + /// # } /// } /// /// # fn main() -> Result<(), Error> { @@ -392,6 +412,8 @@ pub trait GeneticNode { /// # } /// ``` fn mutate(&mut self) -> Result<(), Error>; + + fn merge(left: &Self, right: &Self) -> Result, Error>; } /// Used externally to wrap a node implementing the [`GeneticNode`] trait. Processes state transitions for the given node as @@ -427,7 +449,7 @@ where /// /// impl GeneticNode for Node { /// //... - /// # fn initialize() -> Result, Error> { + /// # fn initialize() -> Result, Error> { /// # Ok(Box::new(Node {fit_score: 0.0})) /// # } /// # @@ -447,6 +469,10 @@ where /// # fn mutate(&mut self) -> Result<(), Error> { /// # Ok(()) /// # } + /// # + /// # fn merge(left: &Node, right: &Node) -> Result, Error> { + /// # Ok(Box::new(Node {fit_score: 0.0})) + /// # } /// } /// /// # fn main() -> Result<(), Error> { diff --git a/gemla/src/bracket/mod.rs b/gemla/src/bracket/mod.rs index 29804ca..82ecbda 100644 --- a/gemla/src/bracket/mod.rs +++ b/gemla/src/bracket/mod.rs @@ -59,6 +59,10 @@ use std::path; /// # fn initialize() -> Result, Error> { /// # Ok(Box::new(TestState { score: 0.0 })) /// # } +/// # +/// # fn merge(left: &TestState, right: &TestState) -> Result, Error> { +/// # Ok(Box::new(left.clone())) +/// # } /// # } /// # /// # fn main() { @@ -176,6 +180,10 @@ where /// } /// /// //... + /// # + /// # fn merge(left: &TestState, right: &TestState) -> Result, Error> { + /// # Ok(Box::new(left.clone())) + /// # } /// } /// /// # fn main() { @@ -245,6 +253,10 @@ where /// # fn initialize() -> Result, Error> { /// # Ok(Box::new(TestState { score: 0.0 })) /// # } + /// # + /// # fn merge(left: &TestState, right: &TestState) -> Result, Error> { + /// # Ok(Box::new(left.clone())) + /// # } /// # } /// # /// # fn main() { @@ -353,6 +365,10 @@ where /// # fn initialize() -> Result, Error> { /// # Ok(Box::new(TestState { score: 0.0 })) /// # } + /// # + /// # fn merge(left: &TestState, right: &TestState) -> Result, Error> { + /// # Ok(Box::new(left.clone())) + /// # } /// # } /// # /// # fn main() { @@ -442,9 +458,17 @@ mod tests { Ok(()) } - fn initialize() -> Result, Error> { + fn initialize() -> Result, Error> { Ok(Box::new(TestState { score: 0.0 })) } + + fn merge(left: &TestState, right: &TestState) -> Result, Error> { + Ok(Box::new(if left.get_fit_score() > right.get_fit_score() { + left.clone() + } else { + right.clone() + })) + } } #[test]