From 9081fb0b3c650bdb1c36760acfee5ed3d0dbdc03 Mon Sep 17 00:00:00 2001 From: vandomej Date: Tue, 5 Oct 2021 10:29:48 -0700 Subject: [PATCH] Adding unit tests for test states --- gemla/src/bin/test_state/mod.rs | 95 +++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) diff --git a/gemla/src/bin/test_state/mod.rs b/gemla/src/bin/test_state/mod.rs index 329a172..eb404e1 100644 --- a/gemla/src/bin/test_state/mod.rs +++ b/gemla/src/bin/test_state/mod.rs @@ -106,3 +106,98 @@ impl GeneticNode for TestState { Ok(Box::new(result)) } } + +#[cfg(test)] +mod tests { + use super::*; + use gemla::bracket::genetic_node::GeneticNode; + + #[test] + fn test_initialize() { + let state = TestState::initialize().unwrap(); + + assert_eq!(state.population.len(), POPULATION_SIZE as usize); + } + + #[test] + fn test_simulate() { + let mut state = TestState { + thread_rng: thread_rng(), + population: vec![1.0, 1.0, 2.0, 3.0], + }; + + let original_population = state.population.clone(); + + state.simulate(0).unwrap(); + assert_eq!(original_population, state.population); + + state.simulate(1).unwrap(); + assert!(original_population + .iter() + .zip(state.population.iter()) + .all(|(&a, &b)| b >= a - 10.0 && b <= a + 10.0)); + + state.simulate(2).unwrap(); + assert!(original_population + .iter() + .zip(state.population.iter()) + .all(|(&a, &b)| b >= a - 30.0 && b <= a + 30.0)) + } + + #[test] + fn test_get_fit_score() { + let state = TestState { + thread_rng: thread_rng(), + population: vec![1.0, 1.0, 2.0, 3.0], + }; + + assert_eq!(state.get_fit_score(), 3.0); + } + + #[test] + fn test_calculate_scores_and_trim() { + let mut state = TestState { + thread_rng: thread_rng(), + population: vec![4.0, 1.0, 1.0, 3.0, 2.0], + }; + + state.calculate_scores_and_trim().unwrap(); + + assert_eq!(state.population.len(), POPULATION_REDUCTION_SIZE as usize); + assert!(state.population.iter().any(|&x| x == 4.0)); + assert!(state.population.iter().any(|&x| x == 3.0)); + assert!(state.population.iter().any(|&x| x == 2.0)); + } + + #[test] + fn test_mutate() { + let mut state = TestState { + thread_rng: thread_rng(), + population: vec![4.0, 3.0, 3.0], + }; + + state.mutate().unwrap(); + + assert_eq!(state.population.len(), POPULATION_SIZE as usize); + } + + #[test] + fn test_merge() { + let state1 = TestState { + thread_rng: thread_rng(), + population: vec![1.0, 2.0, 4.0, 5.0], + }; + + let state2 = TestState { + thread_rng: thread_rng(), + population: vec![0.0, 1.0, 3.0, 7.0], + }; + + let merged_state = TestState::merge(&state1, &state2).unwrap(); + + assert_eq!(merged_state.population.len(), POPULATION_SIZE as usize); + assert!(merged_state.population.iter().any(|&x| x == 7.0)); + assert!(merged_state.population.iter().any(|&x| x == 5.0)); + assert!(merged_state.population.iter().any(|&x| x == 4.0)); + } +}