use gemla::bracket::genetic_node::GeneticNode; use gemla::error; use rand::prelude::*; use rand::thread_rng; use serde::{Deserialize, Serialize}; use std::convert::TryInto; const POPULATION_SIZE: u64 = 5; const POPULATION_REDUCTION_SIZE: u64 = 3; #[derive(Serialize, Deserialize, Debug)] pub struct TestState { pub population: Vec, } impl GeneticNode for TestState { fn initialize() -> Result, error::Error> { let mut population: Vec = vec![]; for _ in 0..POPULATION_SIZE { population.push(thread_rng().gen_range(0..10000)) } Ok(Box::new(TestState { population })) } fn simulate(&mut self, iterations: u64) -> Result<(), error::Error> { let mut rng = thread_rng(); for _ in 0..iterations { self.population = self .population .clone() .iter() .map(|p| p + rng.gen_range(-10..10)) .collect() } Ok(()) } fn mutate(&mut self) -> Result<(), error::Error> { let mut rng = thread_rng(); let mut v = self.population.clone(); v.sort_unstable(); v.reverse(); self.population = v[0..(POPULATION_REDUCTION_SIZE as usize)].to_vec(); loop { if self.population.len() >= POPULATION_SIZE.try_into().unwrap() { break; } let new_individual_index = rng.gen_range(0..self.population.len()); let mut cross_breed_index = rng.gen_range(0..self.population.len()); loop { if new_individual_index != cross_breed_index { break; } cross_breed_index = rng.gen_range(0..self.population.len()); } let mut new_individual = self.population.clone()[new_individual_index]; let cross_breed = self.population.clone()[cross_breed_index]; new_individual += cross_breed + rng.gen_range(-10..10); self.population.push(new_individual); } Ok(()) } fn merge(left: &TestState, right: &TestState) -> Result, error::Error> { let mut v = left.population.clone(); v.append(&mut right.population.clone()); v.sort_by(|a, b| a.partial_cmp(b).unwrap()); v.reverse(); v = v[..(POPULATION_REDUCTION_SIZE as usize)].to_vec(); let mut result = TestState { population: v }; result.mutate()?; Ok(Box::new(result)) } } #[cfg(test)] mod tests { use super::*; use gemla::bracket::genetic_node::GeneticNode; #[test] fn test_initialize() { let state = TestState::initialize().unwrap(); assert_eq!(state.population.len(), POPULATION_SIZE as usize); } #[test] fn test_simulate() { let mut state = TestState { population: vec![1, 1, 2, 3], }; let original_population = state.population.clone(); state.simulate(0).unwrap(); assert_eq!(original_population, state.population); state.simulate(1).unwrap(); assert!(original_population .iter() .zip(state.population.iter()) .all(|(&a, &b)| b >= a - 10 && b <= a + 10)); state.simulate(2).unwrap(); assert!(original_population .iter() .zip(state.population.iter()) .all(|(&a, &b)| b >= a - 30 && b <= a + 30)) } #[test] fn test_mutate() { let mut state = TestState { population: vec![4, 3, 3], }; state.mutate().unwrap(); assert_eq!(state.population.len(), POPULATION_SIZE as usize); } #[test] fn test_merge() { let state1 = TestState { population: vec![1, 2, 4, 5], }; let state2 = TestState { population: vec![0, 1, 3, 7], }; let merged_state = TestState::merge(&state1, &state2).unwrap(); assert_eq!(merged_state.population.len(), POPULATION_SIZE as usize); assert!(merged_state.population.iter().any(|&x| x == 7)); assert!(merged_state.population.iter().any(|&x| x == 5)); assert!(merged_state.population.iter().any(|&x| x == 4)); } }