Adding merge function to TestState

This commit is contained in:
vandomej 2021-10-05 00:14:19 -07:00
parent 6d29508b83
commit 87ade09e36
3 changed files with 98 additions and 18 deletions

View file

@ -1,19 +1,20 @@
use rand::prelude::*;
use rand::rngs::ThreadRng;
use gemla::bracket::genetic_node::GeneticNode; use gemla::bracket::genetic_node::GeneticNode;
use gemla::error; use gemla::error;
use rand::prelude::*;
use rand::rngs::ThreadRng;
use std::convert::TryInto; use std::convert::TryInto;
const POPULATION_SIZE: u64 = 5; const POPULATION_SIZE: u64 = 5;
const POPULATION_REDUCTION_SIZE: u64 = 3;
struct TestState { struct TestState {
pub population: Vec<f64>, pub population: Vec<f64>,
thread_rng: ThreadRng thread_rng: ThreadRng,
} }
impl GeneticNode for TestState { impl GeneticNode for TestState {
fn initialize() -> Result<Box<Self>, error::Error> { fn initialize() -> Result<Box<Self>, error::Error> {
let mut thread_rng = rand::thread_rng(); let mut thread_rng = thread_rng();
let mut population: Vec<f64> = vec![]; let mut population: Vec<f64> = vec![];
for _ in 0..POPULATION_SIZE { for _ in 0..POPULATION_SIZE {
@ -22,28 +23,38 @@ impl GeneticNode for TestState {
Ok(Box::new(TestState { Ok(Box::new(TestState {
population, population,
thread_rng thread_rng,
})) }))
} }
fn simulate(&mut self, iterations: u64) -> Result<(), error::Error> { fn simulate(&mut self, iterations: u64) -> Result<(), error::Error> {
for _ in 0..iterations { for _ in 0..iterations {
self.population = self.population.clone().iter().map(|p| p + self.thread_rng.gen_range(-10.0..10.0)).collect() self.population = self
.population
.clone()
.iter()
.map(|p| p + self.thread_rng.gen_range(-10.0..10.0))
.collect()
} }
Ok(()) Ok(())
} }
fn get_fit_score(&self) -> f64 { fn get_fit_score(&self) -> f64 {
self.population.clone().into_iter().reduce(f64::max).unwrap() self.population
.clone()
.into_iter()
.reduce(f64::max)
.unwrap()
} }
fn calculate_scores_and_trim(&mut self) -> Result<(), error::Error> { fn calculate_scores_and_trim(&mut self) -> Result<(), error::Error> {
let mut v = self.population.clone(); let mut v = self.population.clone();
v.sort_by(|a, b| a.partial_cmp(b).unwrap()); v.sort_by(|a, b| a.partial_cmp(b).unwrap());
v.reverse();
self.population = v[4..].to_vec(); self.population = v[0..(POPULATION_REDUCTION_SIZE as usize)].to_vec();
Ok(()) Ok(())
} }
@ -75,4 +86,23 @@ impl GeneticNode for TestState {
Ok(()) Ok(())
} }
fn merge(left: &TestState, right: &TestState) -> Result<Box<TestState>, error::Error> {
let mut v = left.population.clone();
v.append(&mut right.population.clone());
v.sort_by(|a, b| a.partial_cmp(b).unwrap());
v.reverse();
v = v[..(POPULATION_REDUCTION_SIZE as usize)].to_vec();
let mut result = TestState {
population: v,
thread_rng: thread_rng(),
};
result.mutate()?;
Ok(Box::new(result))
}
} }

View file

@ -43,7 +43,7 @@ pub trait GeneticNode {
/// } /// }
/// ///
/// impl GeneticNode for Node { /// impl GeneticNode for Node {
/// fn initialize() -> Result<Box<Self>, Error> { /// fn initialize() -> Result<Box<Node>, Error> {
/// Ok(Box::new(Node {fit_score: 0.0})) /// Ok(Box::new(Node {fit_score: 0.0}))
/// } /// }
/// ///
@ -64,6 +64,10 @@ pub trait GeneticNode {
/// # fn mutate(&mut self) -> Result<(), Error> { /// # fn mutate(&mut self) -> Result<(), Error> {
/// # Ok(()) /// # Ok(())
/// # } /// # }
/// #
/// # fn merge(left: &Node, right: &Node) -> Result<Box<Node>, Error> {
/// # Ok(Box::new(Node {fit_score: 0.0}))
/// # }
/// } /// }
/// ///
/// # fn main() -> Result<(), Error> { /// # fn main() -> Result<(), Error> {
@ -102,7 +106,7 @@ pub trait GeneticNode {
/// } /// }
/// ///
/// impl GeneticNode for Node { /// impl GeneticNode for Node {
/// # fn initialize() -> Result<Box<Self>, Error> { /// # fn initialize() -> Result<Box<Node>, Error> {
/// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]})) /// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]}))
/// # } /// # }
/// # /// #
@ -129,6 +133,10 @@ pub trait GeneticNode {
/// # fn mutate(&mut self) -> Result<(), Error> { /// # fn mutate(&mut self) -> Result<(), Error> {
/// # Ok(()) /// # Ok(())
/// # } /// # }
/// #
/// # fn merge(left: &Node, right: &Node) -> Result<Box<Node>, Error> {
/// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]}))
/// # }
/// } /// }
/// ///
/// # fn main() -> Result<(), Error> { /// # fn main() -> Result<(), Error> {
@ -167,7 +175,7 @@ pub trait GeneticNode {
/// # } /// # }
/// ///
/// impl GeneticNode for Node { /// impl GeneticNode for Node {
/// # fn initialize() -> Result<Box<Self>, Error> { /// # fn initialize() -> Result<Box<Node>, Error> {
/// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]})) /// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]}))
/// # } /// # }
/// # /// #
@ -195,6 +203,10 @@ pub trait GeneticNode {
/// # fn mutate(&mut self) -> Result<(), Error> { /// # fn mutate(&mut self) -> Result<(), Error> {
/// # Ok(()) /// # Ok(())
/// # } /// # }
/// #
/// # fn merge(left: &Node, right: &Node) -> Result<Box<Node>, Error> {
/// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}]}))
/// # }
/// } /// }
/// ///
/// # fn main() -> Result<(), Error> { /// # fn main() -> Result<(), Error> {
@ -233,7 +245,7 @@ pub trait GeneticNode {
/// # } /// # }
/// ///
/// impl GeneticNode for Node { /// impl GeneticNode for Node {
/// # fn initialize() -> Result<Box<Self>, Error> { /// # fn initialize() -> Result<Box<Node>, Error> {
/// # Ok(Box::new(Node { /// # Ok(Box::new(Node {
/// # models: vec![ /// # models: vec![
/// # Model { fit_score: 0.0 }, /// # Model { fit_score: 0.0 },
@ -276,6 +288,10 @@ pub trait GeneticNode {
/// # fn mutate(&mut self) -> Result<(), Error> { /// # fn mutate(&mut self) -> Result<(), Error> {
/// # Ok(()) /// # Ok(())
/// # } /// # }
/// #
/// # fn merge(left: &Node, right: &Node) -> Result<Box<Node>, Error> {
/// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}], population_size: 1}))
/// # }
/// } /// }
/// ///
/// # fn main() -> Result<(), Error> { /// # fn main() -> Result<(), Error> {
@ -328,7 +344,7 @@ pub trait GeneticNode {
/// } /// }
/// ///
/// impl GeneticNode for Node { /// impl GeneticNode for Node {
/// # fn initialize() -> Result<Box<Self>, Error> { /// # fn initialize() -> Result<Box<Node>, Error> {
/// # Ok(Box::new(Node { /// # Ok(Box::new(Node {
/// # models: vec![ /// # models: vec![
/// # Model { fit_score: 0.0 }, /// # Model { fit_score: 0.0 },
@ -374,6 +390,10 @@ pub trait GeneticNode {
/// } /// }
/// } /// }
/// } /// }
/// #
/// # fn merge(left: &Node, right: &Node) -> Result<Box<Node>, Error> {
/// # Ok(Box::new(Node {models: vec![Model {fit_score: 0.0}], population_size: 1}))
/// # }
/// } /// }
/// ///
/// # fn main() -> Result<(), Error> { /// # fn main() -> Result<(), Error> {
@ -392,6 +412,8 @@ pub trait GeneticNode {
/// # } /// # }
/// ``` /// ```
fn mutate(&mut self) -> Result<(), Error>; fn mutate(&mut self) -> Result<(), Error>;
fn merge(left: &Self, right: &Self) -> Result<Box<Self>, Error>;
} }
/// Used externally to wrap a node implementing the [`GeneticNode`] trait. Processes state transitions for the given node as /// Used externally to wrap a node implementing the [`GeneticNode`] trait. Processes state transitions for the given node as
@ -427,7 +449,7 @@ where
/// ///
/// impl GeneticNode for Node { /// impl GeneticNode for Node {
/// //... /// //...
/// # fn initialize() -> Result<Box<Self>, Error> { /// # fn initialize() -> Result<Box<Node>, Error> {
/// # Ok(Box::new(Node {fit_score: 0.0})) /// # Ok(Box::new(Node {fit_score: 0.0}))
/// # } /// # }
/// # /// #
@ -447,6 +469,10 @@ where
/// # fn mutate(&mut self) -> Result<(), Error> { /// # fn mutate(&mut self) -> Result<(), Error> {
/// # Ok(()) /// # Ok(())
/// # } /// # }
/// #
/// # fn merge(left: &Node, right: &Node) -> Result<Box<Node>, Error> {
/// # Ok(Box::new(Node {fit_score: 0.0}))
/// # }
/// } /// }
/// ///
/// # fn main() -> Result<(), Error> { /// # fn main() -> Result<(), Error> {

View file

@ -59,6 +59,10 @@ use std::path;
/// # fn initialize() -> Result<Box<Self>, Error> { /// # fn initialize() -> Result<Box<Self>, Error> {
/// # Ok(Box::new(TestState { score: 0.0 })) /// # Ok(Box::new(TestState { score: 0.0 }))
/// # } /// # }
/// #
/// # fn merge(left: &TestState, right: &TestState) -> Result<Box<Self>, Error> {
/// # Ok(Box::new(left.clone()))
/// # }
/// # } /// # }
/// # /// #
/// # fn main() { /// # fn main() {
@ -176,6 +180,10 @@ where
/// } /// }
/// ///
/// //... /// //...
/// #
/// # fn merge(left: &TestState, right: &TestState) -> Result<Box<Self>, Error> {
/// # Ok(Box::new(left.clone()))
/// # }
/// } /// }
/// ///
/// # fn main() { /// # fn main() {
@ -245,6 +253,10 @@ where
/// # fn initialize() -> Result<Box<Self>, Error> { /// # fn initialize() -> Result<Box<Self>, Error> {
/// # Ok(Box::new(TestState { score: 0.0 })) /// # Ok(Box::new(TestState { score: 0.0 }))
/// # } /// # }
/// #
/// # fn merge(left: &TestState, right: &TestState) -> Result<Box<Self>, Error> {
/// # Ok(Box::new(left.clone()))
/// # }
/// # } /// # }
/// # /// #
/// # fn main() { /// # fn main() {
@ -353,6 +365,10 @@ where
/// # fn initialize() -> Result<Box<Self>, Error> { /// # fn initialize() -> Result<Box<Self>, Error> {
/// # Ok(Box::new(TestState { score: 0.0 })) /// # Ok(Box::new(TestState { score: 0.0 }))
/// # } /// # }
/// #
/// # fn merge(left: &TestState, right: &TestState) -> Result<Box<Self>, Error> {
/// # Ok(Box::new(left.clone()))
/// # }
/// # } /// # }
/// # /// #
/// # fn main() { /// # fn main() {
@ -442,9 +458,17 @@ mod tests {
Ok(()) Ok(())
} }
fn initialize() -> Result<Box<Self>, Error> { fn initialize() -> Result<Box<TestState>, Error> {
Ok(Box::new(TestState { score: 0.0 })) Ok(Box::new(TestState { score: 0.0 }))
} }
fn merge(left: &TestState, right: &TestState) -> Result<Box<TestState>, Error> {
Ok(Box::new(if left.get_fit_score() > right.get_fit_score() {
left.clone()
} else {
right.clone()
}))
}
} }
#[test] #[test]